"vscode:/vscode.git/clone" did not exist on "924225ed5f393fc620344e1e907769209ed11f06"
Commit 96ca7a5e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'bert_operators' into test_bert

parents c35131f8 badacbcc
...@@ -475,8 +475,7 @@ struct onnx_parser ...@@ -475,8 +475,7 @@ struct onnx_parser
if(args.size() == 2) if(args.size() == 2)
{ {
auto s = args[1]->eval(); auto s = args[1]->eval();
if(s.empty()) check_arg_empty(s, "Reshape: dynamic shape is not supported");
MIGRAPHX_THROW("Dynamic shape is not supported.");
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
} }
...@@ -895,10 +894,7 @@ struct onnx_parser ...@@ -895,10 +894,7 @@ struct onnx_parser
} }
migraphx::argument in = args[0]->eval(); migraphx::argument in = args[0]->eval();
if(in.empty()) check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
{
MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
...@@ -949,7 +945,7 @@ struct onnx_parser ...@@ -949,7 +945,7 @@ struct onnx_parser
if(args.empty()) if(args.empty())
{ {
MIGRAPHX_THROW("Parse ConstantOfShape : must have 1 input!"); MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
} }
else else
{ {
...@@ -962,19 +958,22 @@ struct onnx_parser ...@@ -962,19 +958,22 @@ struct onnx_parser
else else
{ {
migraphx::argument in = args[0]->eval(); migraphx::argument in = args[0]->eval();
if(in.empty()) check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
{
MIGRAPHX_THROW("Parse ConstantOfShape: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
s = migraphx::shape{type, dims}; s = migraphx::shape{type, dims};
} }
literal l_out; literal l_out{};
l_val.visit([&](auto val) { l_val.visit([&](auto val) {
// this #ifdef is to avoid a false cppcheck error, will remove later
// when a newer version of cppcheck is used
#ifdef CPPCHECK
using type = float;
#else
using type = std::remove_cv_t<typename decltype(val)::value_type>; using type = std::remove_cv_t<typename decltype(val)::value_type>;
#endif
// l_val contains only one element // l_val contains only one element
std::vector<type> out_vec(s.elements(), *val.begin()); std::vector<type> out_vec(s.elements(), *val.begin());
l_out = literal(s, out_vec); l_out = literal(s, out_vec);
...@@ -989,10 +988,7 @@ struct onnx_parser ...@@ -989,10 +988,7 @@ struct onnx_parser
{ {
auto in_lens = args[0]->get_shape().lens(); auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval(); migraphx::argument arg_s = args[1]->eval();
if(arg_s.empty()) check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
{
MIGRAPHX_THROW("Parse Expand: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims); auto out_lens = compute_broadcasted_lens(in_lens, dims);
...@@ -1746,6 +1742,14 @@ struct onnx_parser ...@@ -1746,6 +1742,14 @@ struct onnx_parser
} }
} }
} }
void check_arg_empty(const argument& arg, const std::string& msg)
{
if(arg.empty())
{
MIGRAPHX_THROW(msg);
}
}
}; };
program parse_onnx(const std::string& name) program parse_onnx(const std::string& name)
......
...@@ -935,7 +935,7 @@ TEST_CASE(cast_test) ...@@ -935,7 +935,7 @@ TEST_CASE(cast_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(const_of_shape1) TEST_CASE(const_of_shape_float)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3}); migraphx::shape ss(migraphx::shape::int32_type, {3});
...@@ -948,20 +948,20 @@ TEST_CASE(const_of_shape1) ...@@ -948,20 +948,20 @@ TEST_CASE(const_of_shape1)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(const_of_shape2) TEST_CASE(const_of_shape_int64)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3}); migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4})); p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4}); migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4});
std::vector<int64_t> vec(s.elements(), 10.0f); std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec)); p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape2.onnx"); auto prog = migraphx::parse_onnx("const_of_shape2.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(const_of_shape3) TEST_CASE(const_of_shape_no_value_attr)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3}); migraphx::shape ss(migraphx::shape::int32_type, {3});
...@@ -974,7 +974,7 @@ TEST_CASE(const_of_shape3) ...@@ -974,7 +974,7 @@ TEST_CASE(const_of_shape3)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(const_of_shape4) TEST_CASE(const_of_shape_empty_input)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(migraphx::literal()); p.add_literal(migraphx::literal());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment