Commit 95616a07 authored by Umang Yadav

Enable simplify qdq to work with FP8 types

parent 6a72e8fc
......@@ -72,8 +72,8 @@ struct dequantizelinear
visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
visit_all(result, x_scale)([&](auto output, auto scales) {
par_for(output_shape.elements(), [&](auto i) {
- output[i] = static_cast<double>(static_cast<int64_t>(input[i]) -
-                                 static_cast<int64_t>(zero_pts[i])) *
+ output[i] = static_cast<double>(static_cast<double>(input[i]) -
+                                 static_cast<double>(zero_pts[i])) *
scales[i];
});
});
......
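Note on the dequantizelinear hunk above: the reference kernel computes y[i] = (x[i] - zero_point[i]) * scale[i], and routing the operands through int64_t was only lossless for integer inputs; FP8 inputs and zero points carry fractional values, so the subtraction now happens in double. A minimal standalone sketch of the updated arithmetic (plain C++ for illustration, not MIGraphX code; the template stands in for whatever element type visit_all dispatches):

```cpp
#include <cstdio>

// Sketch of the per-element dequantize step after this change. Converting the
// operands to double (instead of int64_t) keeps fractional FP8 values intact;
// integer inputs behave exactly as before.
template <class In, class Zp>
double dequantize_elem(In x, Zp zero_point, double scale)
{
    return (static_cast<double>(x) - static_cast<double>(zero_point)) * scale;
}

int main()
{
    // int8-style input: same result as the old int64_t path
    std::printf("%f\n", dequantize_elem(-3, 0, 0.5)); // -1.5
    // FP8-style input: 1.75 would have been truncated to 1 by an int64_t cast
    std::printf("%f\n", dequantize_elem(1.75, 0.0, 0.5)); // 0.875
    return 0;
}
```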
......@@ -80,10 +80,10 @@ struct quantizelinear
auto min_value = std::numeric_limits<quant_type>::min();
auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) {
- int64_t quantized = static_cast<int64_t>(std::nearbyint(input[i] / scales[i])) +
-                     static_cast<int64_t>(zero_pts[i]);
- output[i] = std::max(static_cast<int64_t>(min_value),
-                      std::min(static_cast<int64_t>(max_value), quantized));
+ double quantized = static_cast<double>(std::nearbyint(input[i] / scales[i])) +
+                    static_cast<double>(zero_pts[i]);
+ output[i] = std::max(static_cast<double>(min_value),
+                      std::min(static_cast<double>(max_value), quantized));
});
});
});
......
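Likewise for quantizelinear: the reference computes q[i] = clamp(nearbyint(x[i] / scale[i]) + zero_point[i], min, max), and with FP8 outputs both the zero point and the type's limits (a ±240 finite range for fp8e4m3fnuz) are floating point, so the intermediate and the clamp bounds are kept in double. A small self-contained sketch under those assumptions:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Sketch of the per-element quantize step after this change: the rounded value
// and the clamp bounds are doubles, so non-integral FP8 zero points and the
// FP8 min/max fit the same code path as the integer types.
double quantize_elem(double x, double scale, double zero_point, double min_v, double max_v)
{
    double quantized = std::nearbyint(x / scale) + zero_point;
    return std::max(min_v, std::min(max_v, quantized));
}

int main()
{
    // int8-style bounds: saturates at 127
    std::printf("%f\n", quantize_elem(300.0, 2.0, 0.0, -128.0, 127.0));
    // fp8e4m3fnuz-style bounds: saturates at the finite maximum 240
    std::printf("%f\n", quantize_elem(600.0, 2.0, 0.0, -240.0, 240.0));
    return 0;
}
```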
......@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
}
- int64_t max_quant = 0;
- int64_t min_quant = 0;
+ double max_quant = 0;
+ double min_quant = 0;
ins->get_shape().visit_type([&](auto qt) {
max_quant = qt.max();
min_quant = qt.min();
......@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{
- std::vector<int> min_data(s.elements(), min_quant);
- std::vector<int> max_data(s.elements(), max_quant);
+ std::vector<double> min_data(s.elements(), min_quant);
+ std::vector<double> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data));
}
......
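The same idea drives the two apply_quantizelinear hunks: the clamp bounds obtained from the output type via visit_type are now held as double, and the CK-workaround clip literals are built as double vectors, so one code path serves both integer and FP8 quantized types. A rough stand-alone illustration (using std::numeric_limits as a stand-in for MIGraphX's per-type max()/min(); there is no std FP8 type, so that case is only noted in the comment):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

// Build per-element clamp literals as doubles so the same helper can hold
// int8 bounds (-128/127) or FP8 bounds (about -240/240 for fp8e4m3fnuz,
// which std::numeric_limits does not model; MIGraphX gets them from visit_type).
template <class QuantType>
std::vector<double> make_clamp_literal(std::size_t elements, bool use_max)
{
    double bound = use_max ? static_cast<double>(std::numeric_limits<QuantType>::max())
                           : static_cast<double>(std::numeric_limits<QuantType>::lowest());
    return std::vector<double>(elements, bound);
}

int main()
{
    auto max_i8 = make_clamp_literal<std::int8_t>(4, true);
    std::printf("int8 max clamp element: %f\n", max_i8.front()); // 127
    return 0;
}
```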
......@@ -82,18 +82,21 @@ struct match_find_quantizable_ops
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto
- propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop)
- {
-     auto qinp     = dqins->inputs().front();
-     auto next_ins = dqins;
-     while(next_ins != qop)
-     {
-         if(next_ins->name() != "dequantizelinear")
-         {
-             qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp);
-         }
-         next_ins = next_ins->outputs().front();
-     }
-     return qinp;
- }
+ propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop_arg)
+ {
+     auto prev_ins = qop_arg;
+     std::vector<instruction_ref> ins_inbetween;
+     // matcher skips continguous, multi/broadcasts and transposes, collect all those
+     // instructions
+     while(prev_ins != dqins)
+     {
+         ins_inbetween.push_back(prev_ins);
+         prev_ins = prev_ins->inputs().front();
+     }
+     auto qinp = dqins->inputs().front();
+     for(auto ins : reverse_iterator_for(ins_inbetween))
+     {
+         qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp});
+     }
+     return qinp;
+ }
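The reworked propagate_quantized_ins appears to be what makes a dequantizelinear with several consumers work (exercised by the new dot_add_multiple_dq_use test below): instead of walking forward from dqins via outputs().front(), which assumes a single consumer chain, it walks backwards from the specific qop argument, records the contiguous/broadcast/transpose instructions in between, and replays them oldest-first on the quantized input. A toy sketch of that traversal with plain pointers instead of instruction_ref (illustration only, not the MIGraphX API):

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Toy node standing in for migraphx::instruction_ref (illustration only).
struct Node
{
    std::string name;
    Node* input = nullptr; // single producer; the skipped ops are all single-input
};

// Walk producers from the qop argument back to the dequantizelinear, collecting
// the shape-only ops in between; the real pass then re-inserts them, oldest
// first, on the quantized input.
std::vector<Node*> collect_ins_inbetween(Node* qop_arg, Node* dqins)
{
    std::vector<Node*> ins_inbetween;
    for(Node* prev = qop_arg; prev != dqins; prev = prev->input)
        ins_inbetween.push_back(prev);
    return ins_inbetween;
}

int main()
{
    Node dq{"dequantizelinear"};
    Node t{"transpose", &dq};
    Node mb{"multibroadcast", &t};
    Node c{"contiguous", &mb};
    // Replaying in reverse rebuilds transpose -> multibroadcast -> contiguous
    // on top of the quantized input.
    auto chain = collect_ins_inbetween(&c, &dq);
    for(auto it = chain.rbegin(); it != chain.rend(); ++it)
        std::printf("%s\n", (*it)->name.c_str());
    return 0;
}
```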
......@@ -124,10 +127,11 @@ struct match_find_quantizable_ops
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
- // Only INT8 type currently supported
- if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or
-    dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type)
+ // Only INT8 or FP8 type currently supported
+ std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fnuz_type,
+                                                      migraphx::shape::int8_type};
+ if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
+    not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
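The type gate is now a set lookup instead of a hard-coded int8 comparison, so further quantized types can be added by extending the set. A minimal stand-in for the check (plain enum and std::set::count instead of migraphx::shape::type_t and the contains helper used in the diff):

```cpp
#include <set>

// Stand-in for migraphx::shape::type_t; only the members needed here.
enum class dtype { int8, fp8e4m3fnuz, half, fp32 };

// Sketch of the gate: bail out of the rewrite unless both dequantize inputs
// use a supported quantized element type.
bool both_inputs_supported(dtype a, dtype b)
{
    static const std::set<dtype> supported = {dtype::int8, dtype::fp8e4m3fnuz};
    return supported.count(a) != 0 and supported.count(b) != 0;
}

int main()
{
    return both_inputs_supported(dtype::int8, dtype::fp8e4m3fnuz) ? 0 : 1;
}
```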
......@@ -140,8 +144,8 @@ struct match_find_quantizable_ops
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs();
- qop_args.at(0) = propagate_quantized_ins(m, dq1, qop);
- qop_args.at(1) = propagate_quantized_ins(m, dq2, qop);
+ qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0]);
+ qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1]);
instruction_ref dq;
instruction_ref out_scale;
instruction_ref zero_point;
......
......@@ -527,6 +527,62 @@ TEST_CASE(dot_add)
EXPECT(m1 == m2);
}
TEST_CASE(dot_add_multiple_dq_use)
{
migraphx::shape sh1{migraphx::shape::float_type, {32, 1}};
migraphx::shape sh2{migraphx::shape::float_type, {32, 32}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto d1_t =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d1);
auto d1_tmb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {32, 32}}}), d1_t);
auto d1_tmbc = m1.add_instruction(migraphx::make_op("contiguous"), d1_tmb);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot_1 = m1.add_instruction(migraphx::make_op("dot"), d1_tmbc, d2);
auto q3 = add_quantize_op(m1, "quantizelinear", dot_1, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto dot_2 = m1.add_instruction(migraphx::make_op("dot"), d3, d1);
auto add = m1.add_instruction(migraphx::make_op("add"), {dot_2, d1});
m1.add_return({add});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q1_t =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q1);
auto q1_tmb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {32, 32}}}), q1_t);
auto q1_tmbc = m2.add_instruction(migraphx::make_op("contiguous"), q1_tmb);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot_1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1_tmbc, q2);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot_1->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot_1, out_scale);
auto d3_q = add_quantize_op(m2, "quantizelinear", d3, scale, zero);
auto dot_2 = m2.add_instruction(migraphx::make_op("quant_dot"), d3_q, q1);
auto out_scale_2 = add_scale_mul(m2, scale, scale, 1, 1, dot_2->get_shape().lens());
auto d4 = add_quantize_op(m2, "dequantizelinear", dot_2, out_scale_2);
auto add = m2.add_instruction(migraphx::make_op("add"), d4, t1);
m2.add_return({add});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv)
{
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
......@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet)
auto mod1 = create_module();
auto mod2 = create_module();
run_pass(mod2);
auto match_qdq = migraphx::match::name("dequantizelinear")(
......