"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "4d471bdaac896d41d963936dba5f4df8551575df"
Unverified Commit 9d71a5e6 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Parse gemm type mismatch (#895)



* fix an issue for type mismatch in parsing gemm

* clang format

* add unit tests

* clang format

* add missing onnx file
Co-authored-by: default avatarChris Austen <causten@users.noreply.github.com>
parent 68032c62
...@@ -43,13 +43,16 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -43,13 +43,16 @@ struct parse_gemm : op_parser<parse_gemm>
std::swap(*perm.rbegin(), *(perm.rbegin() + 1)); std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = args[0]; auto l1 = args[0];
auto dot_type = l1->get_shape().type();
if(alpha != 1.0f) if(alpha != 1.0f)
{ {
auto alpha_literal = info.add_literal(alpha); auto alpha_literal = info.add_literal(alpha);
auto alpha_l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1); l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
l1 = info.add_instruction(make_op("convert", {{"target_type", l1->get_shape().type()}}), if(l1->get_shape().type() != dot_type)
alpha_l1); {
l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1);
}
} }
l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1; l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1;
...@@ -70,12 +73,15 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -70,12 +73,15 @@ struct parse_gemm : op_parser<parse_gemm>
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]); make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
} }
auto beta_literal = info.add_literal(beta); auto beta_literal = info.add_literal(beta);
auto beta_broadcast = info.add_instruction( auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
make_op("multibroadcast", {{"output_lens", out_lens}}), beta_literal); if(beta_l3->get_shape().type() != dot_type)
l3 = info.add_instruction(make_op("mul"), l3, beta_broadcast); {
beta_l3 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_l3);
}
return info.add_instruction( return info.add_instruction(
make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2, l3); make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2, beta_l3);
} }
} }
......
...@@ -1437,6 +1437,23 @@ def gemm_ex_brcst_test(): ...@@ -1437,6 +1437,23 @@ def gemm_ex_brcst_test():
return ([node], [m1, m2, m3], [y]) return ([node], [m1, m2, m3], [y])
@onnx_test
def gemm_half_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [1, 1, 8, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [1, 1, 8, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [1, 1, 6, 1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [1, 1, 6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['1', '2', '3'],
outputs=['y'],
alpha=0.5,
beta=0.8,
transA=1)
return ([node], [m1, m2, m3], [y])
@onnx_test @onnx_test
def globalavgpool_test(): def globalavgpool_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16])
......
...@@ -1292,8 +1292,6 @@ TEST_CASE(gemm_test) ...@@ -1292,8 +1292,6 @@ TEST_CASE(gemm_test)
auto beta = 2.0f; auto beta = 2.0f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
...@@ -1320,9 +1318,6 @@ TEST_CASE(gemm_ex_test) ...@@ -1320,9 +1318,6 @@ TEST_CASE(gemm_ex_test)
auto beta = 0.8f; auto beta = 0.8f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
...@@ -1348,9 +1343,6 @@ TEST_CASE(gemm_ex_brcst_test) ...@@ -1348,9 +1343,6 @@ TEST_CASE(gemm_ex_brcst_test)
auto beta = 0.8f; auto beta = 0.8f;
auto a_l = mm->add_literal(alpha); auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
auto b_l = mm->add_literal(beta); auto b_l = mm->add_literal(beta);
...@@ -1367,6 +1359,37 @@ TEST_CASE(gemm_ex_brcst_test) ...@@ -1367,6 +1359,37 @@ TEST_CASE(gemm_ex_brcst_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(gemm_half_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::half_type, {1, 1, 6, 1}});
auto alpha = 0.5f;
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7};
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), l2);
l2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2);
auto b_l = mm->add_literal(beta);
auto b_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
l2_b = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2_b);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_b);
auto prog = optimize_onnx("gemm_half_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(globalavgpool_test) TEST_CASE(globalavgpool_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment