Commit 40dcfa6b authored by Khalique's avatar Khalique
Browse files

added more operator parsers with edge cases

parent 1a2e5c11
...@@ -391,7 +391,7 @@ struct onnx_parser ...@@ -391,7 +391,7 @@ struct onnx_parser
} }
if(contains(attributes, "beta")) if(contains(attributes, "beta"))
{ {
alpha = parse_value(attributes.at("beta")).at<float>(); beta = parse_value(attributes.at("beta")).at<float>();
} }
if(contains(attributes, "transA")) if(contains(attributes, "transA"))
{ {
......
...@@ -421,6 +421,36 @@ void constant_test() ...@@ -421,6 +421,36 @@ void constant_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void gemm_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}});
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto d0 = p.add_instruction(migraphx::op::dot{2, 2}, t0, t1);
auto b0 = p.add_instruction(migraphx::op::broadcast{1, d0->get_shape()}, l2);
p.add_instruction(migraphx::op::add{}, d0, b0);
auto prog = migraphx::parse_onnx("gemm_test.onnx");
EXPECT(p == prog);
}
void add_scalar_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 0, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 0, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1);
auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
EXPECT(p == prog);
}
int main() int main()
{ {
pytorch_conv_bias_test(); pytorch_conv_bias_test();
...@@ -454,4 +484,6 @@ int main() ...@@ -454,4 +484,6 @@ int main()
concat_test(); concat_test();
slice_test(); slice_test();
constant_test(); constant_test();
gemm_test();
add_scalar_test();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment