"driver/olCompiling/addkernels/include_inliner.cpp" did not exist on "d2315b0dfcd6f31cca4328819eaf60d77e952dd6"
Commit c7239863 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add the shape op

parent 157935ff
...@@ -161,6 +161,7 @@ register_migraphx_ops( ...@@ -161,6 +161,7 @@ register_migraphx_ops(
rsqrt rsqrt
scalar scalar
scatter scatter
shape_op
sigmoid sigmoid
sign sign
sinh sinh
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SHAPE_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SHAPE_OP_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/context.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct shape_op
{
std::string name() const { return "shape"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
std::vector<std::size_t> lens = {inputs[0].lens().size()};
return {shape::int64_type, lens};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto lens = args.front().get_shape().lens();
result.visit([&](auto v) {
std::copy(lens.begin(), lens.end(), v.begin());
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -86,6 +86,7 @@ ...@@ -86,6 +86,7 @@
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp> #include <migraphx/op/scatter.hpp>
#include <migraphx/op/shape_op.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp> #include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp> #include <migraphx/op/sinh.hpp>
......
...@@ -19,14 +19,19 @@ struct parse_shape : op_parser<parse_shape> ...@@ -19,14 +19,19 @@ struct parse_shape : op_parser<parse_shape>
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
if(args.size() != 1) if(args.size() != 1)
{
MIGRAPHX_THROW("Shape: operator should have 1 operand"); MIGRAPHX_THROW("Shape: operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens(); }
std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()}); return info.add_instruction(make_op("shape"), args);
std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
return int64_t(i); // std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
}); // std::vector<int64_t> vec_shape(arg_shape.size());
return info.add_literal(migraphx::literal{s, vec_shape}); // migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
// std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
// return int64_t(i);
// });
// return info.add_literal(migraphx::literal{s, vec_shape});
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment