Commit c2898ca7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

simplify parsing the gemm operator

parent 1681e49a
...@@ -481,22 +481,20 @@ struct onnx_parser ...@@ -481,22 +481,20 @@ struct onnx_parser
auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1]; auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
if(args.size() == 3) if(args.size() == 3)
{ {
if(beta != 0.f) if(beta != 0.f && args[2]->get_shape().elements() > 0)
{ {
auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2); auto out_lens = l1->get_shape().lens();
auto l4 = args[2]; out_lens.back() = l2->get_shape().lens().back();
if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B) auto l3 = args[2];
return l3; if (!std::equal(out_lens.begin(), out_lens.end(), args[2]->get_shape().lens().begin()))
if(beta != 1.f)
{ {
auto beta_val = prog.add_literal(beta); l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
auto l5 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
l4 = prog.add_instruction(op::mul{}, args[2], l5);
} }
return add_broadcastable_binary_op(l3, l4, op::add{});
return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
} }
} }
return prog.add_instruction(op::dot{alpha, beta}, l1, l2); return prog.add_instruction(op::dot{alpha}, l1, l2);
} }
instruction_ref instruction_ref
......
...@@ -580,14 +580,8 @@ TEST_CASE(gemm_ex) ...@@ -580,14 +580,8 @@ TEST_CASE(gemm_ex)
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}}); auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0); auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto alpha = 0.5f; auto alpha = 0.5f;
auto res_ab = p.add_instruction(migraphx::op::dot{alpha}, t0, l1);
auto beta = 0.8f; auto beta = 0.8f;
auto l_beta = p.add_literal(beta); p.add_instruction(migraphx::op::dot{alpha, beta}, t0, l1, l2);
auto brcst_beta = p.add_instruction(migraphx::op::scalar{l2->get_shape()}, l_beta);
auto res_c = p.add_instruction(migraphx::op::mul{}, l2, brcst_beta);
p.add_instruction(migraphx::op::add{}, res_ab, res_c);
auto prog = migraphx::parse_onnx("gemm_test_ex.onnx"); auto prog = migraphx::parse_onnx("gemm_test_ex.onnx");
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment