Commit dfbfd078 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'shape_op' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into parse_dynamic_shape

parents ecb1545c f20d6acb
......@@ -161,6 +161,7 @@ register_migraphx_ops(
rsqrt
scalar
scatter
shape_op
sigmoid
sign
sinh
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SHAPE_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SHAPE_OP_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/context.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct shape_op
{
std::string name() const { return "shape"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
std::vector<std::size_t> lens = {inputs[0].lens().size()};
return {shape::int64_type, lens};
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto lens = args.front().get_shape().lens();
result.visit([&](auto v) { std::copy(lens.begin(), lens.end(), v.begin()); });
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -86,6 +86,7 @@
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/shape_op.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
......
......@@ -19,14 +19,19 @@ struct parse_shape : op_parser<parse_shape>
std::vector<instruction_ref> args) const
{
if(args.size() != 1)
{
MIGRAPHX_THROW("Shape: operator should have 1 operand");
std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
std::vector<int64_t> vec_shape(arg_shape.size());
migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
return int64_t(i);
});
return info.add_literal(migraphx::literal{s, vec_shape});
}
return info.add_instruction(make_op("shape"), args);
// std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
// std::vector<int64_t> vec_shape(arg_shape.size());
// migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
// std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
// return int64_t(i);
// });
// return info.add_literal(migraphx::literal{s, vec_shape});
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment