Commit 19877c95 authored by Khalique

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into pow_op

parents 8fd95296 21d4395e
@@ -68,11 +68,13 @@ struct onnx_parser
add_mem_op("ArgMax", &onnx_parser::parse_argmax);
add_mem_op("ArgMin", &onnx_parser::parse_argmin);
add_mem_op("Cast", &onnx_parser::parse_cast);
add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
add_mem_op("Elu", &onnx_parser::parse_elu);
add_mem_op("Expand", &onnx_parser::parse_expand);
add_mem_op("Constant", &onnx_parser::parse_constant);
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
@@ -93,6 +95,7 @@ struct onnx_parser
add_mem_op("Gather", &onnx_parser::parse_gather);
add_mem_op("Shape", &onnx_parser::parse_shape);
add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
add_mem_op("Transpose", &onnx_parser::parse_transpose);
add_mem_op("RNN", &onnx_parser::parse_rnn);
add_mem_op("GRU", &onnx_parser::parse_gru);
@@ -464,8 +467,7 @@ struct onnx_parser
if(args.size() == 2)
{
auto s = args[1]->eval();
if(s.empty())
MIGRAPHX_THROW("Dynamic shape is not supported.");
check_arg_empty(s, "Reshape: dynamic shape is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
@@ -544,7 +546,13 @@ struct onnx_parser
attribute_map attributes,
const std::vector<instruction_ref>&)
{
literal v = parse_value(attributes.at("value"));
// return empty literal
if(v.get_shape().elements() == 0)
{
return prog.add_literal(literal{});
}
auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar
if(dim_size == 0)
@@ -872,10 +880,7 @@ struct onnx_parser
}
migraphx::argument in = args[0]->eval();
if(in.empty())
{
MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
}
check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
@@ -903,6 +908,74 @@ struct onnx_parser
}
}
instruction_ref parse_constant_of_shape(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
literal l_val{};
if(contains(attributes, "value"))
{
l_val = parse_value(attributes.at("value"));
if(l_val.get_shape().elements() != 1)
{
MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
}
}
else
{
l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
}
// the output literal takes the element type of the value attribute
auto type = l_val.get_shape().type();
if(args.empty())
{
MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
}
else
{
migraphx::shape s;
// empty input tensor, output is a scalar
if(args[0]->get_shape().elements() == 0)
{
s = migraphx::shape{type, {1}, {0}};
}
else
{
migraphx::argument in = args[0]->eval();
check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
s = migraphx::shape{type, dims};
}
literal l_out{};
l_val.visit([&](auto val) {
using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
// l_val contains only one element
std::vector<val_type> out_vec(s.elements(), *val.begin());
l_out = literal(s, out_vec);
});
return prog.add_literal(l_out);
}
}
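// Worked example (an illustration, not code from this commit): with an input
// shape tensor holding {2, 3, 4} and a float value attribute of 10.0f, the
// parser above builds the equivalent of
//
//     migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
//     std::vector<float> out_vec(s.elements(), 10.0f); // 24 copies of 10.0f
//     prog.add_literal(migraphx::literal(s, out_vec));
//
// which is what the const_of_shape_float test below constructs.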
instruction_ref
parse_expand(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval();
check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
}
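// A minimal sketch (hypothetical, not part of this commit) of the numpy-style
// broadcasting that compute_broadcasted_lens is assumed to perform above:
// dimensions are right-aligned, and each aligned pair must be equal or
// contain a 1, which is then stretched to match.
std::vector<std::size_t> broadcast_lens_sketch(std::vector<std::size_t> s0,
                                               std::vector<std::size_t> s1)
{
    if(s0.size() < s1.size())
        std::swap(s0, s1); // make s0 the longer shape
    auto offset = s0.size() - s1.size();
    for(std::size_t i = 0; i < s1.size(); ++i)
    {
        auto& d = s0[i + offset];
        if(d == 1)
            d = s1[i];
        else if(s1[i] != 1 and s1[i] != d)
            MIGRAPHX_THROW("Expand: incompatible broadcast dimensions");
    }
    return s0; // e.g. {3, 1, 1} vs {2, 3, 4, 5} -> {2, 3, 4, 5}
}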
std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
@@ -1325,6 +1398,19 @@ struct onnx_parser
}
}
instruction_ref
parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
if(!contains(attributes, "to"))
{
MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
}
int to_type = parse_value(attributes.at("to")).at<int>();
shape::type_t type = get_type(to_type);
return prog.add_instruction(op::convert{type}, std::move(args));
}
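// Hypothetical sketch (not shown in this diff) of the get_type helper used
// above; the case values follow the public ONNX TensorProto::DataType enum.
shape::type_t get_type_sketch(int onnx_dtype)
{
    switch(onnx_dtype)
    {
    case 1: return shape::float_type;   // FLOAT
    case 6: return shape::int32_type;   // INT32
    case 7: return shape::int64_type;   // INT64
    case 10: return shape::half_type;   // FLOAT16
    case 11: return shape::double_type; // DOUBLE
    default: MIGRAPHX_THROW("Unsupported ONNX type: " + std::to_string(onnx_dtype));
    }
}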
void parse_from(std::istream& is)
{
onnx::ModelProto model;
@@ -1637,6 +1723,14 @@ struct onnx_parser
}
}
}
void check_arg_empty(const argument& arg, const std::string& msg)
{
if(arg.empty())
{
MIGRAPHX_THROW(msg);
}
}
};
program parse_onnx(const std::string& name)
......
@@ -8,6 +8,7 @@ namespace device {
void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
{
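// sum{} is the fold operator with initial value 0; the two id{} arguments
// are (presumably) identity transforms applied when reading input elements
// and writing the reduced result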
reduce(stream, result, arg, sum{}, 0, id{}, id{});
}
......
[binary ONNX test graphs added: "cast-example" (Cast with a "to" attribute), two "constant-of-shape" graphs (ConstantOfShape with and without a "value" attribute), and "expand" (a Constant shape tensor feeding Expand); raw protobuf bytes omitted]
@@ -925,4 +925,78 @@ TEST_CASE(pow_test)
EXPECT(p == prog);
}
TEST_CASE(cast_test)
{
migraphx::program p;
auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}});
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, l);
auto prog = migraphx::parse_onnx("cast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_float)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
std::vector<float> vec(s.elements(), 10.0f);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape1.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_int64)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4});
std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape2.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_no_value_attr)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
std::vector<float> vec(s.elements(), 0.0f);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape3.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape_empty_input)
{
migraphx::program p;
p.add_literal(migraphx::literal());
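// empty shape tensor input: the result is a scalar, represented here with
// lens {1} and stride {0}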
migraphx::shape s(migraphx::shape::int64_type, {1}, {0});
std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape4.onnx");
EXPECT(p == prog);
}
TEST_CASE(expand_test)
{
migraphx::program p;
migraphx::shape s(migraphx::shape::float_type, {3, 1, 1});
auto param = p.add_parameter("x", s);
migraphx::shape ss(migraphx::shape::int32_type, {4});
p.add_literal(migraphx::literal(ss, {2, 3, 4, 5}));
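// {3, 1, 1} right-aligned against {2, 3, 4, 5} broadcasts to {2, 3, 4, 5}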
p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, param);
auto prog = migraphx::parse_onnx("expand_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }