Unverified Commit 9d71a5e6 authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Parse gemm type mismatch (#895)



* fix an issue for type mismatch in parsing gemm

* clang format

* add unit tests

* clang format

* add missing onnx file
Co-authored-by: default avatarChris Austen <causten@users.noreply.github.com>
parent 68032c62
......@@ -42,14 +42,17 @@ struct parse_gemm : op_parser<parse_gemm>
// swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = args[0];
auto l1 = args[0];
auto dot_type = l1->get_shape().type();
if(alpha != 1.0f)
{
auto alpha_literal = info.add_literal(alpha);
auto alpha_l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
l1 = info.add_instruction(make_op("convert", {{"target_type", l1->get_shape().type()}}),
alpha_l1);
l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
if(l1->get_shape().type() != dot_type)
{
l1 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}), l1);
}
}
l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1;
......@@ -69,13 +72,16 @@ struct parse_gemm : op_parser<parse_gemm>
l3 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
}
auto beta_literal = info.add_literal(beta);
auto beta_broadcast = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), beta_literal);
l3 = info.add_instruction(make_op("mul"), l3, beta_broadcast);
auto beta_literal = info.add_literal(beta);
auto beta_l3 = info.add_broadcastable_binary_op("mul", l3, beta_literal);
if(beta_l3->get_shape().type() != dot_type)
{
beta_l3 = info.add_instruction(make_op("convert", {{"target_type", dot_type}}),
beta_l3);
}
return info.add_instruction(
make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2, l3);
make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2, beta_l3);
}
}
......
......@@ -1437,6 +1437,23 @@ def gemm_ex_brcst_test():
return ([node], [m1, m2, m3], [y])
@onnx_test
def gemm_half_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [1, 1, 8, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [1, 1, 8, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT16, [1, 1, 6, 1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [1, 1, 6, 7])
node = onnx.helper.make_node('Gemm',
inputs=['1', '2', '3'],
outputs=['y'],
alpha=0.5,
beta=0.8,
transA=1)
return ([node], [m1, m2, m3], [y])
@onnx_test
def globalavgpool_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 16, 16])
......
......@@ -1292,10 +1292,8 @@ TEST_CASE(gemm_test)
auto beta = 2.0f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto b_l = mm->add_literal(beta);
auto l2_b =
......@@ -1320,10 +1318,7 @@ TEST_CASE(gemm_ex_test)
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction(
......@@ -1348,10 +1343,7 @@ TEST_CASE(gemm_ex_brcst_test)
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", l1->get_shape().type()}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
auto b_l = mm->add_literal(beta);
auto l2_b =
......@@ -1367,6 +1359,37 @@ TEST_CASE(gemm_ex_brcst_test)
EXPECT(p == prog);
}
TEST_CASE(gemm_half_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 6}});
auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::half_type, {1, 1, 8, 7}});
auto l2 = mm->add_parameter("3", migraphx::shape{migraphx::shape::half_type, {1, 1, 6, 1}});
auto alpha = 0.5f;
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7};
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), l2);
l2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2);
auto b_l = mm->add_literal(beta);
auto b_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
l2_b = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2_b);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_b);
auto prog = optimize_onnx("gemm_half_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(globalavgpool_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment