Unverified Commit 22f60c24 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Enable simplify qdq to work with FP8 types (#2528)

parent dfc18d6c
...@@ -72,8 +72,8 @@ struct dequantizelinear ...@@ -72,8 +72,8 @@ struct dequantizelinear
visit_all(x, x_zero_point)([&](auto input, auto zero_pts) { visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
visit_all(result, x_scale)([&](auto output, auto scales) { visit_all(result, x_scale)([&](auto output, auto scales) {
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
output[i] = static_cast<double>(static_cast<int64_t>(input[i]) - output[i] = static_cast<double>(static_cast<double>(input[i]) -
static_cast<int64_t>(zero_pts[i])) * static_cast<double>(zero_pts[i])) *
scales[i]; scales[i];
}); });
}); });
......
...@@ -80,10 +80,10 @@ struct quantizelinear ...@@ -80,10 +80,10 @@ struct quantizelinear
auto min_value = std::numeric_limits<quant_type>::min(); auto min_value = std::numeric_limits<quant_type>::min();
auto max_value = std::numeric_limits<quant_type>::max(); auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
int64_t quantized = static_cast<int64_t>(std::nearbyint(input[i] / scales[i])) + double quantized = static_cast<double>(std::nearbyint(input[i] / scales[i])) +
static_cast<int64_t>(zero_pts[i]); static_cast<double>(zero_pts[i]);
output[i] = std::max(static_cast<int64_t>(min_value), output[i] = std::max(static_cast<double>(min_value),
std::min(static_cast<int64_t>(max_value), quantized)); std::min(static_cast<double>(max_value), quantized));
}); });
}); });
}); });
......
...@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point); add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
} }
int64_t max_quant = 0; double max_quant = 0;
int64_t min_quant = 0; double min_quant = 0;
ins->get_shape().visit_type([&](auto qt) { ins->get_shape().visit_type([&](auto qt) {
max_quant = qt.max(); max_quant = qt.max();
min_quant = qt.min(); min_quant = qt.min();
...@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{})) if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{ {
std::vector<int> min_data(s.elements(), min_quant); std::vector<double> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant); std::vector<double> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data)); min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data)); max_arg = m.add_literal(literal(s, max_data));
} }
......
...@@ -82,18 +82,21 @@ struct match_find_quantizable_ops ...@@ -82,18 +82,21 @@ struct match_find_quantizable_ops
// Helper function to insert quantized versions of any broadcasts and transpose ops that // Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op // occur between dequantizelinear and the quantized op
static auto static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop) propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop_arg)
{ {
auto qinp = dqins->inputs().front(); auto prev_ins = qop_arg;
auto next_ins = dqins; std::vector<instruction_ref> ins_inbetween;
// matcher skips continguous, multi/broadcasts and transposes, collect all those
while(next_ins != qop) // instructions
while(prev_ins != dqins)
{ {
if(next_ins->name() != "dequantizelinear") ins_inbetween.push_back(prev_ins);
{ prev_ins = prev_ins->inputs().front();
qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp); }
} auto qinp = dqins->inputs().front();
next_ins = next_ins->outputs().front(); for(auto ins : reverse_iterator_for(ins_inbetween))
{
qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp});
} }
return qinp; return qinp;
} }
...@@ -124,10 +127,11 @@ struct match_find_quantizable_ops ...@@ -124,10 +127,11 @@ struct match_find_quantizable_ops
auto scale2 = r.instructions["scale2"]; auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"]; auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"]; auto zp2 = r.instructions["zp2"];
// Only INT8 or FP8 type currently supported
// Only INT8 type currently supported std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fnuz_type,
if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or migraphx::shape::int8_type};
dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type) if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return; return;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed) // Only symmetric quantization supported (ie. non-zero zero_points not allowed)
...@@ -140,8 +144,8 @@ struct match_find_quantizable_ops ...@@ -140,8 +144,8 @@ struct match_find_quantizable_ops
// Propagate q1 and q2 through any broadcasts and transposes before qop // Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs(); auto qop_args = qop->inputs();
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop); qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0]);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop); qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1]);
instruction_ref dq; instruction_ref dq;
instruction_ref out_scale; instruction_ref out_scale;
instruction_ref zero_point; instruction_ref zero_point;
......
...@@ -527,6 +527,62 @@ TEST_CASE(dot_add) ...@@ -527,6 +527,62 @@ TEST_CASE(dot_add)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
// Verifies the simplify_qdq pass when one dequantizelinear output (d1) feeds
// multiple consumers: a transpose/multibroadcast/contiguous chain into the
// first dot, the second dot directly, and a final add. After the pass, the
// dots should become quant_dots operating on the quantized values, while the
// add (a non-quantizable consumer) falls back to the original float input t1.
TEST_CASE(dot_add_multiple_dq_use)
{
    migraphx::shape sh1{migraphx::shape::float_type, {32, 1}};
    migraphx::shape sh2{migraphx::shape::float_type, {32, 32}};
    // m1: the input module — explicit Q/DQ pairs around each dot.
    migraphx::module m1;
    {
        auto t1    = m1.add_parameter("t1", sh1);
        auto t2    = m1.add_parameter("t2", sh2);
        auto scale = m1.add_literal(0.5f);
        // int8 zero point of 0 — symmetric quantization, which is what the
        // pass requires (non-zero zero points are rejected).
        auto zero = m1.add_literal(std::int8_t{0});
        // add_quantize_op is a test helper defined elsewhere in this file;
        // presumably it inserts the named op with broadcast scale/zero-point
        // arguments — confirm against its definition.
        auto q1   = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
        auto d1   = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
        // Shape-manipulating ops between d1 and the dot; the pass must
        // propagate quantized values through these.
        auto d1_t =
            m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d1);
        auto d1_tmb =
            m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {32, 32}}}), d1_t);
        auto d1_tmbc = m1.add_instruction(migraphx::make_op("contiguous"), d1_tmb);
        auto q2      = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
        auto d2      = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
        auto dot_1   = m1.add_instruction(migraphx::make_op("dot"), d1_tmbc, d2);
        auto q3      = add_quantize_op(m1, "quantizelinear", dot_1, scale, zero);
        auto d3      = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
        // d1 is reused here (second dot) and again in the add below —
        // the "multiple dq use" scenario this test is named for.
        auto dot_2 = m1.add_instruction(migraphx::make_op("dot"), d3, d1);
        auto add   = m1.add_instruction(migraphx::make_op("add"), {dot_2, d1});
        m1.add_return({add});
    }
    // m2: the expected module after the pass.
    migraphx::module m2;
    {
        auto t1    = m2.add_parameter("t1", sh1);
        auto t2    = m2.add_parameter("t2", sh2);
        auto scale = m2.add_literal(0.5f);
        auto zero  = m2.add_literal(std::int8_t{0});
        auto q1    = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
        // The transpose/multibroadcast/contiguous chain now operates on the
        // quantized q1 rather than a dequantized value.
        auto q1_t =
            m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q1);
        auto q1_tmb =
            m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {32, 32}}}), q1_t);
        auto q1_tmbc = m2.add_instruction(migraphx::make_op("contiguous"), q1_tmb);
        auto q2      = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
        // dot -> quant_dot on the int8 inputs, followed by a dequantize with
        // the combined output scale (add_scale_mul multiplies the two input
        // scales — helper defined elsewhere; verify against its definition).
        auto dot_1     = m2.add_instruction(migraphx::make_op("quant_dot"), q1_tmbc, q2);
        auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot_1->get_shape().lens());
        auto d3        = add_quantize_op(m2, "dequantizelinear", dot_1, out_scale);
        auto d3_q      = add_quantize_op(m2, "quantizelinear", d3, scale, zero);
        auto dot_2     = m2.add_instruction(migraphx::make_op("quant_dot"), d3_q, q1);
        auto out_scale_2 = add_scale_mul(m2, scale, scale, 1, 1, dot_2->get_shape().lens());
        auto d4          = add_quantize_op(m2, "dequantizelinear", dot_2, out_scale_2);
        // The non-quantizable add consumes the original float parameter t1
        // directly instead of a dequantized copy.
        auto add = m2.add_instruction(migraphx::make_op("add"), d4, t1);
        m2.add_return({add});
    }
    // The pass must transform m1 into exactly m2 (instruction-for-instruction).
    run_pass(m1);
    EXPECT(m1 == m2);
}
TEST_CASE(conv) TEST_CASE(conv)
{ {
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}}; migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
...@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet) ...@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet)
auto mod1 = create_module(); auto mod1 = create_module();
auto mod2 = create_module(); auto mod2 = create_module();
run_pass(mod2); run_pass(mod2);
auto match_qdq = migraphx::match::name("dequantizelinear")( auto match_qdq = migraphx::match::name("dequantizelinear")(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment