Commit 3059c59c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix review comments

parent b003c8fd
......@@ -467,8 +467,7 @@ struct onnx_parser
if(args.size() == 2)
{
auto s = args[1]->eval();
if(s.empty())
MIGRAPHX_THROW("Dynamic shape is not supported.");
check_arg_empty(s, "Reshape: dynamic shape is not supported");
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
......@@ -881,10 +880,7 @@ struct onnx_parser
}
migraphx::argument in = args[0]->eval();
if(in.empty())
{
MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
}
check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
......@@ -948,10 +944,7 @@ struct onnx_parser
else
{
migraphx::argument in = args[0]->eval();
if(in.empty())
{
MIGRAPHX_THROW("Parse ConstantOfShape: cannot handle dynamic shape as input");
}
check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
......@@ -975,10 +968,7 @@ struct onnx_parser
{
auto in_lens = args[0]->get_shape().lens();
migraphx::argument arg_s = args[1]->eval();
if(arg_s.empty())
{
MIGRAPHX_THROW("Parse Expand: cannot handle dynamic shape as input");
}
check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
......@@ -1733,6 +1723,14 @@ struct onnx_parser
}
}
}
void check_arg_empty(const argument& arg, const std::string& msg)
{
if (arg.empty())
{
MIGRAPHX_THROW(msg);
}
}
};
program parse_onnx(const std::string& name)
......
......@@ -935,7 +935,7 @@ TEST_CASE(cast_test)
EXPECT(p == prog);
}
TEST_CASE(const_of_shape1)
TEST_CASE(const_of_shape_float)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
......@@ -948,20 +948,20 @@ TEST_CASE(const_of_shape1)
EXPECT(p == prog);
}
TEST_CASE(const_of_shape2)
TEST_CASE(const_of_shape_int64)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4});
std::vector<int64_t> vec(s.elements(), 10.0f);
std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape2.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape3)
TEST_CASE(const_of_shape_no_value_attr)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
......@@ -974,7 +974,7 @@ TEST_CASE(const_of_shape3)
EXPECT(p == prog);
}
TEST_CASE(const_of_shape4)
TEST_CASE(const_of_shape_empty_input)
{
migraphx::program p;
p.add_literal(migraphx::literal());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment