Commit 99ebfe11 authored by Gyula Zakor
Browse files

Update MatMulInteger op parsing

parent 7e53592e
...@@ -62,9 +62,10 @@ struct parse_matmul : op_parser<parse_matmul> ...@@ -62,9 +62,10 @@ struct parse_matmul : op_parser<parse_matmul>
a1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]); a1 = info.add_instruction(make_op("unsqueeze", {{"axes", {1}}}), args[1]);
} }
auto is_quant_dot = opd.op_name == "quant_dot";
if(s0.dynamic() or s1.dynamic()) if(s0.dynamic() or s1.dynamic())
{ {
if(opd.op_name == "quant_dot") if(is_quant_dot)
{ {
MIGRAPHX_THROW("PARSE_MATMUL: dynamic MatMulInteger not supported"); MIGRAPHX_THROW("PARSE_MATMUL: dynamic MatMulInteger not supported");
} }
...@@ -111,7 +112,38 @@ struct parse_matmul : op_parser<parse_matmul> ...@@ -111,7 +112,38 @@ struct parse_matmul : op_parser<parse_matmul>
make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1); make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1);
} }
} }
dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
// MatMulInteger can accept uint8 as input type or have zero point values
// In these case fall back to dot with half float inputs
auto ba0_type = ba0->get_shape().type();
auto ba1_type = ba1->get_shape().type();
auto has_a0_zero_point = args.size() > 2;
auto has_a1_zero_point = args.size() > 3;
if(is_quant_dot and (ba0_type == migraphx::shape::uint8_type or
ba1_type == migraphx::shape::uint8_type or has_a0_zero_point))
{
// gpu implementation (gemm) only accepts floating point types for dot
ba0 = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::half_type}}), ba0);
ba1 = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::half_type}}), ba1);
if(has_a0_zero_point)
{
ba0 = info.add_common_op("sub", ba0, args[2]);
}
if(has_a1_zero_point)
{
ba1 = info.add_common_op("sub", ba1, args[3]);
}
dot_res = info.add_instruction(make_op("dot"), ba0, ba1);
dot_res = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::int32_type}}), dot_res);
}
else
{
dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1);
}
} }
// squeeze the appended or prepended dimensions // squeeze the appended or prepended dimensions
......
...@@ -4866,6 +4866,23 @@ def matmulinteger_dyn_error(): ...@@ -4866,6 +4866,23 @@ def matmulinteger_dyn_error():
return ([node], [m1, m2], [y]) return ([node], [m1, m2], [y])
@onnx_test()
def matmulinteger_unsigned_test():
    # MatMulInteger with uint8 operands and scalar zero points for both
    # inputs (a_zero_point = 12, b_zero_point = 0); output accumulates in int32.
    a = helper.make_tensor_value_info('1', TensorProto.UINT8, [4, 3])
    b = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2])
    a_zero_point = helper.make_tensor('3', TensorProto.UINT8, [], [12])
    b_zero_point = helper.make_tensor('4', TensorProto.UINT8, [], [0])
    out = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2])
    node = onnx.helper.make_node('MatMulInteger',
                                 inputs=['1', '2', '3', '4'],
                                 outputs=['y'])
    return ([node], [a, b], [out], [a_zero_point, b_zero_point])
@onnx_test() @onnx_test()
def max_test(): def max_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......
...@@ -1215,6 +1215,29 @@ TEST_CASE(lpnormalization_2norm) ...@@ -1215,6 +1215,29 @@ TEST_CASE(lpnormalization_2norm)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(matmulinteger_unsigned_test)
{
    // Parse and run a MatMulInteger node with uint8 inputs and scalar zero
    // points (a_zero_point = 12, b_zero_point = 0).  The parser lowers this
    // to a floating-point dot with zero points subtracted, converted to int32.
    migraphx::program p = migraphx::parse_onnx("matmulinteger_unsigned_test.onnx");
    // This test runs on the reference target; the GPU-only offload_copy
    // compile option that was previously passed here is not needed.
    p.compile(migraphx::make_target("ref"));

    migraphx::shape s0{migraphx::shape::uint8_type, {4, 3}};
    std::vector<uint8_t> data0 = {11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0};
    migraphx::shape s1{migraphx::shape::uint8_type, {3, 2}};
    std::vector<uint8_t> data1 = {1, 4, 2, 5, 3, 6};

    migraphx::parameter_map pp;
    pp["1"] = migraphx::argument(s0, data0.data());
    pp["2"] = migraphx::argument(s1, data1.data());

    auto result = p.eval(pp).back();
    std::vector<int32_t> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    // gold[i][j] = sum_k (data0[i][k] - 12) * (data1[k][j] - 0), e.g.
    // row 0: (-1, -5, -9) . (1, 2, 3) = -38 and (-1, -5, -9) . (4, 5, 6) = -83.
    std::vector<int32_t> gold = {-38, -83, -44, -98, -50, -113, -56, -128};
    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(mean_broadcast_test) TEST_CASE(mean_broadcast_test)
{ {
migraphx::program p = migraphx::parse_onnx("mean_broadcast_test.onnx"); migraphx::program p = migraphx::parse_onnx("mean_broadcast_test.onnx");
......
...@@ -134,7 +134,6 @@ def disabled_tests_onnx_1_7_0(backend_test): ...@@ -134,7 +134,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test.exclude(r'test_hardmax_example_cpu') backend_test.exclude(r'test_hardmax_example_cpu')
backend_test.exclude(r'test_hardmax_negative_axis_cpu') backend_test.exclude(r'test_hardmax_negative_axis_cpu')
backend_test.exclude(r'test_hardmax_one_hot_cpu') backend_test.exclude(r'test_hardmax_one_hot_cpu')
backend_test.exclude(r'test_matmulinteger_cpu')
backend_test.exclude(r'test_maxpool_2d_uint8_cpu') backend_test.exclude(r'test_maxpool_2d_uint8_cpu')
backend_test.exclude(r'test_maxunpool_export_with_output_shape_cpu') backend_test.exclude(r'test_maxunpool_export_with_output_shape_cpu')
backend_test.exclude(r'test_maxunpool_export_without_output_shape_cpu') backend_test.exclude(r'test_maxunpool_export_without_output_shape_cpu')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment