"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "f1de9bc154afa408f29763dbebfe66b6d29d32f8"
Commit 794d0e76 authored by Gyula Zakor's avatar Gyula Zakor
Browse files

Add uint8 to quant_dot

parent 99ebfe11
......@@ -41,12 +41,13 @@ struct quant_dot
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.same_type().has(2);
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
if(t != shape::int8_type)
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
std::set<migraphx::shape::type_t> suppported_types = {shape::int8_type, shape::uint8_type};
if(not contains(suppported_types, t))
{
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and uint8_t");
}
if(not std::all_of(
......
......@@ -113,32 +113,20 @@ struct parse_matmul : op_parser<parse_matmul>
}
}
// MatMulInteger can accept uint8 as input type or have zero point values
// In these case fall back to dot with half float inputs
auto ba0_type = ba0->get_shape().type();
auto ba1_type = ba1->get_shape().type();
auto has_a0_zero_point = args.size() > 2;
auto has_a1_zero_point = args.size() > 3;
if(is_quant_dot and (ba0_type == migraphx::shape::uint8_type or
ba1_type == migraphx::shape::uint8_type or has_a0_zero_point))
// parse a_zero_point and b_zero_point values
if(args.size() > 2)
{
// gpu implementation (gemm) only accepts floating point types for dot
ba0 = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::half_type}}), ba0);
ba1 = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::half_type}}), ba1);
make_op("convert", {{"target_type", migraphx::shape::float_type}}), ba0);
if(has_a0_zero_point)
{
ba0 = info.add_common_op("sub", ba0, args[2]);
}
if(has_a1_zero_point)
ba0 = info.add_common_op("sub", ba0, args[2]);
if(args.size() > 3)
{
ba1 = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::float_type}}), ba1);
ba1 = info.add_common_op("sub", ba1, args[3]);
}
dot_res = info.add_instruction(make_op("dot"), ba0, ba1);
dot_res = info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::int32_type}}), dot_res);
}
else
{
......
......@@ -196,7 +196,7 @@ struct gemm_impl
arg_type = get_type(input_shapes[0].type());
output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
if(output_type == rocblas_datatype_i8_r or output_type == rocblas_datatype_u8_r)
{
output_type = rocblas_datatype_i32_r;
}
......
......@@ -140,6 +140,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_qdq{},
enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{},
// workaround for rocBLAS unsupported error when using uint8 in quant_dot
eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_dot"}},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
simplify_reshapes{},
eliminate_identity{},
......
......@@ -1218,9 +1218,7 @@ TEST_CASE(lpnormalization_2norm)
TEST_CASE(matmulinteger_unsigned_test)
{
migraphx::program p = migraphx::parse_onnx("matmulinteger_unsigned_test.onnx");
migraphx::compile_options gpu_opt;
gpu_opt.offload_copy = true;
p.compile(migraphx::make_target("ref"), gpu_opt);
p.compile(migraphx::make_target("ref"));
migraphx::shape s0{migraphx::shape::uint8_type, {4, 3}};
std::vector<uint8_t> data0 = {11, 7, 3, 10, 6, 2, 9, 5, 1, 8, 4, 0};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment