Commit 03f0e278 authored by charlie's avatar charlie
Browse files

Fix parsing and add test

parent 22012c6d
...@@ -113,15 +113,19 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -113,15 +113,19 @@ struct parse_gemm : op_parser<parse_gemm>
make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]); make_op("multibroadcast", {{"out_lens", out_lens}}), args[2]);
} }
} }
auto beta_literal = info.add_literal(beta);
auto beta_c = info.add_broadcastable_binary_op("mul", c_arg, beta_literal); if(not float_equal(beta, 1.0f))
if(beta_c->get_shape().type() != dot_type)
{ {
beta_c = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), auto beta_literal = info.add_literal(beta);
beta_c); c_arg = info.add_broadcastable_binary_op("mul", c_arg, beta_literal);
if(c_arg->get_shape().type() != dot_type)
{
c_arg = info.add_instruction(
make_op("convert", {{"target_type", dot_type}}), c_arg);
}
} }
return info.add_instruction(make_op("add"), dot_ins, beta_c); return info.add_instruction(make_op("add"), dot_ins, c_arg);
} }
} }
return dot_ins; return dot_ins;
......
...@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test(): ...@@ -2215,7 +2215,7 @@ def gemm_dyn_outer_test():
@onnx_test() @onnx_test()
def gemm_dyn_C_error(): def gemm_dyn_bias_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, None]) A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [8, None])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [8, 7]) B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 7]) C = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1, 7])
......
...@@ -2278,6 +2278,26 @@ TEST_CASE(gemm_dyn_outer_test) ...@@ -2278,6 +2278,26 @@ TEST_CASE(gemm_dyn_outer_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gemm_dyn_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x0 =
mm->add_parameter("A", migraphx::shape{migraphx::shape::float_type, {{8, 8}, {1, 10}}});
auto x1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::float_type, {8, 7}});
auto x2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::float_type, {1, 7}});
auto x0_t = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x0);
auto dot = mm->add_instruction(migraphx::make_op("dot"), x0_t, x1);
auto x2_b = mm->add_instruction(migraphx::make_op("multibroadcast"), x2, dot);
auto ret = mm->add_instruction(migraphx::make_op("add"), dot, x2_b);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10};
auto prog = parse_onnx("gemm_dyn_bias_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(gemm_rank_error) TEST_CASE(gemm_rank_error)
{ {
EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_rank_error.onnx"); })); EXPECT(test::throws([&] { migraphx::parse_onnx("gemm_rank_error.onnx"); }));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment