Commit 40dcfa6b authored by Khalique's avatar Khalique
Browse files

added more operator parsers with edge cases

parent 1a2e5c11
......@@ -391,7 +391,7 @@ struct onnx_parser
}
if(contains(attributes, "beta"))
{
alpha = parse_value(attributes.at("beta")).at<float>();
beta = parse_value(attributes.at("beta")).at<float>();
}
if(contains(attributes, "transA"))
{
......
......@@ -421,6 +421,36 @@ void constant_test()
EXPECT(p == prog);
}
void gemm_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {11, 5}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {}});
auto t0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto d0 = p.add_instruction(migraphx::op::dot{2, 2}, t0, t1);
auto b0 = p.add_instruction(migraphx::op::broadcast{1, d0->get_shape()}, l2);
p.add_instruction(migraphx::op::add{}, d0, b0);
auto prog = migraphx::parse_onnx("gemm_test.onnx");
EXPECT(p == prog);
}
void add_scalar_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 0, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{0, 0, 0, 5}}, l1);
p.add_instruction(migraphx::op::add{}, m0, m1);
auto prog = migraphx::parse_onnx("add_scalar_test.onnx");
EXPECT(p == prog);
}
int main()
{
pytorch_conv_bias_test();
......@@ -454,4 +484,6 @@ int main()
concat_test();
slice_test();
constant_test();
gemm_test();
add_scalar_test();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment