Commit 32815c2e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add test cases for operator in the bert model

parent 24a6fec2
...@@ -545,6 +545,12 @@ struct onnx_parser ...@@ -545,6 +545,12 @@ struct onnx_parser
const std::vector<instruction_ref>&) const std::vector<instruction_ref>&)
{ {
literal v = parse_value(attributes.at("value")); literal v = parse_value(attributes.at("value"));
// return empty literal
if (v.get_shape().elements() == 0)
{
return prog.add_literal(literal{});
}
auto dim_size = attributes.at("value").t().dims_size(); auto dim_size = attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar // if dim_size is 0, it is a scalar
if(dim_size == 0) if(dim_size == 0)
...@@ -923,21 +929,31 @@ struct onnx_parser ...@@ -923,21 +929,31 @@ struct onnx_parser
// input is empty, output is a scalar // input is empty, output is a scalar
auto type = l_val.get_shape().type(); auto type = l_val.get_shape().type();
if(args.size() == 0) if(args.size() == 0)
{ {
return prog.add_literal(literal({type, {1}, {0}}, l_val.data())); MIGRAPHX_THROW("Parse ConstantOfShape : must have 1 input!");
}
else
{
migraphx::shape s;
// empty input tensor, output is a scalar
if (args[0]->get_shape().elements() == 0)
{
s = migraphx::shape{type, {1}, {0}};
} }
else else
{ {
migraphx::argument in = args[0]->eval(); migraphx::argument in = args[0]->eval();
if(in.empty()) if(in.empty())
{ {
MIGRAPHX_THROW("ConstantOfShape: cannot handle dynamic shape as input"); MIGRAPHX_THROW("Parse ConstantOfShape: cannot handle dynamic shape as input");
} }
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
in.visit([&](auto input) { dims.assign(input.begin(), input.end()); }); in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
migraphx::shape s(type, dims); s = migraphx::shape{type, dims};
}
literal l_out; literal l_out;
l_val.visit([&](auto val) { l_val.visit([&](auto val) {
...@@ -955,8 +971,14 @@ struct onnx_parser ...@@ -955,8 +971,14 @@ struct onnx_parser
parse_expand(const std::string&, attribute_map, std::vector<instruction_ref> args) parse_expand(const std::string&, attribute_map, std::vector<instruction_ref> args)
{ {
auto in_lens = args[0]->get_shape().lens(); auto in_lens = args[0]->get_shape().lens();
auto ex_lens = args[1]->get_shape().lens(); migraphx::argument arg_s = args[1]->eval();
auto out_lens = compute_broadcasted_lens(in_lens, ex_lens); if (arg_s.empty())
{
MIGRAPHX_THROW("Parse Expand: cannot handle dynamic shape as input");
}
std::vector<std::size_t> dims;
arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
auto out_lens = compute_broadcasted_lens(in_lens, dims);
return prog.add_instruction(op::multibroadcast{out_lens}, std::move(args[0])); return prog.add_instruction(op::multibroadcast{out_lens}, std::move(args[0]));
} }
......
...@@ -6,10 +6,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -6,10 +6,10 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void pow(hipStream_t stream, const argument& result, const argument& arg2, const argument& arg1) void pow(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{ {
nary(stream, result, arg1, arg2)( nary(stream, result, arg1, arg2)(
[](auto x, auto y) { return ::pow(to_hip_type(x), to_hip_type(y)); }); [](auto e, auto b) { return ::pow(to_hip_type(b), to_hip_type(e)); });
} }
} // namespace device } // namespace device
......
...@@ -8,6 +8,7 @@ namespace device { ...@@ -8,6 +8,7 @@ namespace device {
void reduce_sum(hipStream_t stream, const argument& result, const argument& arg) void reduce_sum(hipStream_t stream, const argument& result, const argument& arg)
{ {
reduce(stream, result, arg, sum{}, 0, id{}, id{}); reduce(stream, result, arg, sum{}, 0, id{}, id{});
} }
......
 cast-example:F

xy"Cast*
to test_castZ
x



b
y


B
constant-of-shape:
6shape"Constant*#
value**B shape_tensor 
7
shapey"ConstantOfShape*
value*:
Bvalue constant_of_shapeb
y



B
constant-of-shape:
6shape"Constant*#
value**B shape_tensor 

shapey"ConstantOfShapeconstant_of_shapeb
y



B
expand:
7shape"Constant*$
value**B shape_tensor

x
shapey"ExpandexpandZ
x



b
y




B
...@@ -884,4 +884,78 @@ TEST_CASE(pow_test) ...@@ -884,4 +884,78 @@ TEST_CASE(pow_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(cast_test)
{
migraphx::program p;
auto l = p.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {10}});
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, l);
auto prog = migraphx::parse_onnx("cast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape1)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
std::vector<float> vec(s.elements(), 10.0f);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape1.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape2)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::int64_type, {2, 3, 4});
std::vector<int64_t> vec(s.elements(), 10.0f);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape2.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape3)
{
migraphx::program p;
migraphx::shape ss(migraphx::shape::int32_type, {3});
p.add_literal(migraphx::literal(ss, {2, 3, 4}));
migraphx::shape s(migraphx::shape::float_type, {2, 3, 4});
std::vector<float> vec(s.elements(), 0.0f);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape3.onnx");
EXPECT(p == prog);
}
TEST_CASE(const_of_shape4)
{
migraphx::program p;
p.add_literal(migraphx::literal());
migraphx::shape s(migraphx::shape::int64_type, {1}, {0});
std::vector<int64_t> vec(s.elements(), 10);
p.add_literal(migraphx::literal(s, vec));
auto prog = migraphx::parse_onnx("const_of_shape4.onnx");
EXPECT(p == prog);
}
TEST_CASE(expand_test)
{
migraphx::program p;
migraphx::shape s(migraphx::shape::float_type, {3, 1, 1});
auto param = p.add_parameter("x", s);
migraphx::shape ss(migraphx::shape::int32_type, {4});
p.add_literal(migraphx::literal(ss, {2, 3, 4, 5}));
p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, param);
auto prog = migraphx::parse_onnx("expand_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment