Commit 22012c6d authored by charlie's avatar charlie
Browse files

Change parse gemm, remove test

parent eaba20d8
......@@ -90,27 +90,28 @@ struct parse_gemm : op_parser<parse_gemm>
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot"), a_arg, b_arg);
auto dot_ins = info.add_instruction(make_op("dot"), a_arg, b_arg);
if(args.size() == 3)
{
// TODO: support dynamic C input
if(std::any_of(args.cbegin(), args.cend(), [](auto in_arg) {
return in_arg->get_shape().dynamic();
}))
if(not float_equal(beta, 0.0f))
{
MIGRAPHX_THROW("PARSE_GEMM: C input not handled for dynamic input shapes");
}
if(not float_equal(beta, 0.0f) and args[2]->get_shape().elements() > 0)
{
auto out_lens = a_arg->get_shape().lens();
out_lens.back() = b_arg->get_shape().lens().back();
auto c_arg = args[2];
auto c_lens = c_arg->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
auto c_arg = args[2];
if(dot_ins->get_shape().dynamic())
{
c_arg = info.add_instruction(
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
c_arg = info.add_instruction(make_op("multibroadcast"), args[2], dot_ins);
}
else
{
auto out_lens = a_arg->get_shape().lens();
out_lens.back() = b_arg->get_shape().lens().back();
auto c_lens = c_arg->get_shape().lens();
if(not std::equal(
out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
{
c_arg = info.add_instruction(
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
}
auto beta_literal = info.add_literal(beta);
auto beta_c = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
......@@ -120,11 +121,10 @@ struct parse_gemm : op_parser<parse_gemm>
beta_c);
}
return info.add_instruction(make_op("add"), ret, beta_c);
return info.add_instruction(make_op("add"), dot_ins, beta_c);
}
}
return ret;
return dot_ins;
}
};
......
......@@ -2278,13 +2278,6 @@ TEST_CASE(gemm_dyn_outer_test)
EXPECT(p == prog);
}
TEST_CASE(gemm_dyn_C_error)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_dyn_C_error.onnx", options); }));
}
TEST_CASE(gemm_rank_error)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_rank_error.onnx"); }));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment