Unverified Commit f2c7e9b3 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Dynamic onnx gemm bias (#1527)

Adds support for parsing dynamic ONNX gemm bias input C
parent d309e02f
...@@ -90,41 +90,45 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -90,41 +90,45 @@ struct parse_gemm : op_parser<parse_gemm>
? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1]) ? info.add_instruction(make_op("transpose", {{"permutation", perm}}), args[1])
: args[1]; : args[1];
auto ret = info.add_instruction(make_op("dot"), a_arg, b_arg); auto dot_ins = info.add_instruction(make_op("dot"), a_arg, b_arg);
if(args.size() == 3) if(args.size() == 3)
{ {
// TODO: support dynamic C input if(not float_equal(beta, 0.0f))
if(std::any_of(args.cbegin(), args.cend(), [](auto in_arg) {
return in_arg->get_shape().dynamic();
}))
{ {
MIGRAPHX_THROW("PARSE_GEMM: C input not handled for dynamic input shapes"); auto c_arg = args[2];
} if(dot_ins->get_shape().dynamic())
if(not float_equal(beta, 0.0f) and args[2]->get_shape().elements() > 0) {
{ c_arg = info.add_instruction(make_op("multibroadcast"), args[2], dot_ins);
auto out_lens = a_arg->get_shape().lens(); }
out_lens.back() = b_arg->get_shape().lens().back(); else
auto c_arg = args[2];
auto c_lens = c_arg->get_shape().lens();
if(not std::equal(out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
{ {
c_arg = info.add_instruction( auto out_lens = a_arg->get_shape().lens();
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]); out_lens.back() = b_arg->get_shape().lens().back();
auto c_lens = c_arg->get_shape().lens();
if(not std::equal(
out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
{
c_arg = info.add_instruction(
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
}
} }
auto beta_literal = info.add_literal(beta);
auto beta_c = info.add_broadcastable_binary_op("mul", c_arg, beta_literal); if(not float_equal(beta, 1.0f))
if(beta_c->get_shape().type() != dot_type)
{ {
beta_c = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), auto beta_literal = info.add_literal(beta);
beta_c); c_arg = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(c_arg->get_shape().type() != dot_type)
{
c_arg = info.add_instruction(
make_op("convert", {{"target_type", dot_type}}), c_arg);
}
} }
return info.add_instruction(make_op("add"), ret, beta_c); return info.add_instruction(make_op("add"), dot_ins, c_arg);
} }
} }
return dot_ins;
return ret;
} }
}; };
......
...@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test(): ...@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test():
@onnx_test() @onnx_test()
def gemm_dyn_C_error(): def gemm_dyn_bias_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, None]) A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, None])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [8, 7]) B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 7]) C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 7])
......
...@@ -2278,11 +2278,24 @@ TEST_CASE(gemm_dyn_outer_test) ...@@ -2278,11 +2278,24 @@ TEST_CASE(gemm_dyn_outer_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gemm_dyn_C_error) TEST_CASE(gemm_dyn_bias_test)
{ {
migraphx::program p;
auto* mm = p.get_main_module();
auto x0 =
mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {{8, 8}, {1, 10}}});
auto x1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {8, 7}});
auto x2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {1, 7}});
auto x0_t = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x0);
auto dot = mm->add_instruction(migraphx::make_op("dot"), x0_t, x1);
auto x2_b = mm->add_instruction(migraphx::make_op("multibroadcast"), x2, dot);
auto ret = mm->add_instruction(migraphx::make_op("add"), dot, x2_b);
mm->add_return({ret});
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0}; options.default_dyn_dim_value = {1, 10};
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_dyn_C_error.onnx", options); })); auto prog = parse_onnx("gemm_dyn_bias_test.onnx", options);
EXPECT(p == prog);
} }
TEST_CASE(gemm_rank_error) TEST_CASE(gemm_rank_error)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment