Commit 22012c6d authored by charlie's avatar charlie
Browse files

Change parse gemm, remove test

parent eaba20d8
...@@ -90,28 +90,29 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -90,28 +90,29 @@ struct parse_gemm : op_parser<parse_gemm>
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1]) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1]; : args[1];
auto ret = info.add_instruction(make_op("dot"), a_arg, b_arg); auto dot_ins = info.add_instruction(make_op("dot"), a_arg, b_arg);
if(args.size() == 3) if(args.size() == 3)
{ {
// TODO: support dynamic C input if(not float_equal(beta, 0.0f))
if(std::any_of(args.cbegin(), args.cend(), [](auto in_arg) {
return in_arg->get_shape().dynamic();
}))
{ {
MIGRAPHX_THROW("PARSE_GEMM: C input not handled for dynamic input shapes"); auto c_arg = args[2];
if(dot_ins->get_shape().dynamic())
{
c_arg = info.add_instruction(make_op("multibroadcast"), args[2], dot_ins);
} }
if(not float_equal(beta, 0.0f) and args[2]->get_shape().elements() > 0) else
{ {
auto out_lens = a_arg->get_shape().lens(); auto out_lens = a_arg->get_shape().lens();
out_lens.back() = b_arg->get_shape().lens().back(); out_lens.back() = b_arg->get_shape().lens().back();
auto c_arg = args[2];
auto c_lens = c_arg->get_shape().lens(); auto c_lens = c_arg->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end())) if(not std::equal(
out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
{ {
c_arg = info.add_instruction( c_arg = info.add_instruction(
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]); make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
} }
}
auto beta_literal = info.add_literal(beta); auto beta_literal = info.add_literal(beta);
auto beta_c = info.add_broadcastable_binary_op("mul", c_arg, beta_literal); auto beta_c = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(beta_c->get_shape().type() != dot_type) if(beta_c->get_shape().type() != dot_type)
...@@ -120,11 +121,10 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -120,11 +121,10 @@ struct parse_gemm : op_parser<parse_gemm>
beta_c); beta_c);
} }
return info.add_instruction(make_op("add"), ret, beta_c); return info.add_instruction(make_op("add"), dot_ins, beta_c);
} }
} }
return dot_ins;
return ret;
} }
}; };
......
...@@ -2278,13 +2278,6 @@ TEST_CASE(gemm_dyn_outer_test) ...@@ -2278,13 +2278,6 @@ TEST_CASE(gemm_dyn_outer_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gemm_dyn_C_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_dyn_C_error.onnx", options); }));
}
TEST_CASE(gemm_rank_error) TEST_CASE(gemm_rank_error)
{ {
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_rank_error.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_rank_error.onnx"); }));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment