Unverified Commit 0039b11a authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

Support per-axis quantization (#2390)

Reworked the simplify_qdq pass to support:

Per-axis quantization (ie. allow 1D scales and zero points)
Allow broadcast and transpose ops between dq and quant_op
parent b2a40ea6
......@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
});
});
return all_same;
}
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{
if(ins->outputs().size() == 1)
......@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}
template <class... Ms>
auto skip_broadcasts_transposes_contiguous(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "transpose"))(ms...);
}
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
......
......@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names()
return s;
}
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
struct match_find_quantizable_ops
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
static bool
is_valid_scale(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
return scale->get_shape().scalar() or scale->get_shape().elements() == lens.at(axis);
}
static bool is_valid_zero_point(instruction_ref zp)
{
if(not zp->can_eval())
return false;
bool all_zeros = false;
zp->eval().visit([&](auto z) {
all_zeros =
std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); });
});
});
return all_same;
}
return all_zeros;
}
struct match_find_quantizable_ops
{
static auto
scale_broadcast_op(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
if(scale->get_shape().scalar())
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
}
static auto dequantizelinear_op(const std::string& name, const std::string& scale)
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop)
{
auto qinp = dqins->inputs().front();
auto next_ins = dqins;
while(next_ins != qop)
{
if(next_ins->name() != "dequantizelinear")
{
qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp);
}
next_ins = next_ins->outputs().front();
}
return qinp;
}
static auto dequantizelinear_op(const std::string& scale, const std::string& zp)
{
return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0)))));
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())),
match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp))));
}
auto matcher() const
{
return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")),
match::arg(1)(dequantizelinear_op("x2", "scale2")));
match::arg(0)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale2", "zp2").bind("dq2"))));
}
void apply(module& m, const match::matcher_result& r) const
{
auto qop = r.result;
auto q1 = r.instructions["x1"];
auto q2 = r.instructions["x2"];
auto dq1 = r.instructions["dq1"];
auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
// Only INT8 type currently supported
if(q1->get_shape().type() != migraphx::shape::int8_type or
q2->get_shape().type() != migraphx::shape::int8_type)
if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or
dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type)
return;
double scale;
visit_all(scale1->get_literal(), scale2->get_literal())(
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); });
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
if(not(is_valid_zero_point(zp1) and is_valid_zero_point(zp2)))
return;
// Only support scalar and 1D scales
if(scale1->get_shape().lens().size() != 1 or scale2->get_shape().lens().size() != 1)
return;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs();
qop_args.at(0) = q1;
qop_args.at(1) = q2;
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop);
instruction_ref dq;
instruction_ref dq_scale;
instruction_ref out_scale;
instruction_ref zero_point;
if(qop->name() == "convolution")
{
auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
auto out_lens = dq->get_shape().lens();
// Input scale should always be scalar and weight scale can be scalar or 1D of the
// same lens as the output channel dim (dim 1 in the output)
if(not(is_valid_scale(scale1, out_lens, 1) and is_valid_scale(scale2, out_lens, 1)))
return;
auto s1_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale1, out_lens, 1), scale1);
auto s2_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale2, out_lens, 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
else if(qop->name() == "dot")
{
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
auto out_lens = dq->get_shape().lens();
// For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M
// for input1 and K for input 2
if(not(is_valid_scale(scale1, out_lens, out_lens.size() - 2) and
is_valid_scale(scale2, out_lens, out_lens.size() - 1)))
return;
auto s1_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1);
auto s2_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
}
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens();
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale);
m.replace_instruction(qop, dq);
}
};
......
......@@ -636,13 +636,12 @@ TEST_CASE(dot_float)
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f);
auto dc = mm->add_literal(100.0f);
auto mdc =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), dc);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, mdc);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args)
auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb);
auto scale_a = mm->add_literal(10.0);
auto zp = mm->add_literal(static_cast<int8_t>(0));
scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a);
auto scale_a_lit = mm->add_literal(10.0);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a_lit);
auto zp_a =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b = mm->add_literal(5.0);
scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b_lit = mm->add_literal(5.0);
auto scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b_lit);
auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto scale = mm->add_literal(50.0);
scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, scale);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto scale_a_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
scale_a_lit);
auto scale_b_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
scale_b_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_a_mb, scale_b_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r});
return p;
};
......@@ -798,19 +801,16 @@ TEST_CASE(dot_half_1arg)
migraphx::shape sa{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", sa);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}),
scale);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_lit = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0}));
dq_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
dq_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, dq_scale);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale, scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r});
return p;
};
......@@ -851,10 +851,10 @@ TEST_CASE(conv_float)
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f);
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_lit = mm->add_literal(10.0f);
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
......@@ -862,13 +862,11 @@ TEST_CASE(conv_float)
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
migraphx::shape sc{migraphx::shape::float_type, {4, 4, 1, 1}};
std::vector<float> vec(sc.elements(), 100.0f);
migraphx::shape s_scale{migraphx::shape::float_type, sc.lens()};
auto d_scale = mm->add_literal(100.0f);
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
scale_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -930,20 +928,21 @@ TEST_CASE(conv_half)
auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}),
scale);
auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale_lit = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
auto scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp);
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
auto d_scale = mm->add_literal(migraphx::literal({sx.type()}, {100.0}));
d_scale = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
auto scale_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
scale_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r});
return p;
......@@ -1185,12 +1184,12 @@ TEST_CASE(int8_subgraph)
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1);
auto zpb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto so = then_mod->add_literal(100.0f);
so = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto s1_mb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), s1);
auto so = then_mod->add_instruction(migraphx::make_op("mul"), s1_mb, s1_mb);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r});
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
......@@ -1199,24 +1198,25 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw);
// else submod
auto* else_mod = p.create_module("If_6_else");
auto sax = else_mod->add_literal(2.0f);
auto sax_lit = else_mod->add_literal(2.0f);
auto zp = else_mod->add_literal(static_cast<int8_t>(0));
sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax);
auto sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax_lit);
auto zpx = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp);
auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx);
auto ssw = else_mod->add_literal(1.66667f);
ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw);
auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx);
auto ssw_lit = else_mod->add_literal(1.66667f);
auto ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw_lit);
auto zpw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp);
auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw);
auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto so1 = else_mod->add_literal(3.33333f);
so1 = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so1);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw);
auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto ssw_mb = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}),
ssw_lit);
auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
else_mod->add_return({r1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
......
......@@ -44,20 +44,34 @@ void run_pass(migraphx::module& m)
sqdq.apply(m);
}
migraphx::instruction_ref add_quantize_op(migraphx::module& m,
const std::string& name,
migraphx::instruction_ref x,
migraphx::instruction_ref broadcast_scale(migraphx::module& m,
migraphx::instruction_ref scale,
migraphx::instruction_ref shift)
const std::vector<std::size_t>& out_lens,
std::size_t axis)
{
auto lens = x->get_shape().lens();
if(scale->get_shape().lens() == out_lens)
return scale;
migraphx::instruction_ref scale_mb;
if(scale->get_shape().lens().front() == 1)
auto scale_lens = scale->get_shape().lens();
if(scale_lens.front() == 1 and scale_lens.size() == 1)
scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale);
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), scale);
else
scale_mb = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale);
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", out_lens}}), scale);
return scale_mb;
}
migraphx::instruction_ref add_quantize_op(migraphx::module& m,
const std::string& name,
migraphx::instruction_ref x,
migraphx::instruction_ref scale,
migraphx::instruction_ref shift,
std::size_t q_axis = 1)
{
auto lens = x->get_shape().lens();
auto scale_mb = broadcast_scale(m, scale, lens, q_axis);
auto shift_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), shift);
return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb);
......@@ -66,19 +80,26 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m,
migraphx::instruction_ref add_quantize_op(migraphx::module& m,
const std::string& name,
migraphx::instruction_ref x,
migraphx::instruction_ref scale)
migraphx::instruction_ref scale,
std::size_t q_axis = 1)
{
auto lens = x->get_shape().lens();
migraphx::instruction_ref scale_mb;
if(scale->get_shape().lens().front() == 1)
scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale);
else
scale_mb = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale);
auto lens = x->get_shape().lens();
auto scale_mb = broadcast_scale(m, scale, lens, q_axis);
return m.add_instruction(migraphx::make_op(name), x, scale_mb);
}
migraphx::instruction_ref add_scale_mul(migraphx::module& m,
migraphx::instruction_ref scale1,
migraphx::instruction_ref scale2,
std::size_t axis1,
std::size_t axis2,
const std::vector<std::size_t>& out_lens)
{
auto scale1_mb = broadcast_scale(m, scale1, out_lens, axis1);
auto scale2_mb = broadcast_scale(m, scale2, out_lens, axis2);
return m.add_instruction(migraphx::make_op("mul"), scale1_mb, scale2_mb);
}
TEST_CASE(remove_qdq)
{
migraphx::shape sh1{migraphx::shape::float_type, {100, 100}};
......@@ -159,18 +180,62 @@ TEST_CASE(dot)
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_multi_scale)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
migraphx::shape sh3{migraphx::shape::float_type, {1280}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m1.add_literal(0.4f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 0);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 0);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 1);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 1);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto scale1 = m2.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m2.add_literal(0.4f);
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 0);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 1);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto out_scale = add_scale_mul(m2, scale1, scale2, 0, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
......@@ -178,6 +243,180 @@ TEST_CASE(dot)
EXPECT(m1 == m2);
}
TEST_CASE(dot_broadcasted)
{
migraphx::shape sh1{migraphx::shape::float_type, {2, 1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto d2_mb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 1000, 1024}}}), d2);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto q2_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 1000, 1024}}}), q2);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_transposed)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto d2_t =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_t);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto q2_t =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_t);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_multi_scale_transposed_broadcasted)
{
migraphx::shape sh1{migraphx::shape::float_type, {2, 3, 1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}};
migraphx::shape sh3{migraphx::shape::float_type, {1280}};
migraphx::shape sh4{migraphx::shape::float_type, {1024}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m1.add_literal(migraphx::generate_literal(sh4, 0));
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 2);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 2);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 0);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 0);
auto d2_t =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2);
auto d2_mb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), d2_t);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale1 = m2.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m2.add_literal(migraphx::generate_literal(sh4, 0));
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 2);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 0);
auto q2_t =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2);
auto q2_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), q2_t);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb);
auto out_scale = add_scale_mul(m2, scale1, scale2, 2, 3, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_multi_scale_unsupported_axis)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
migraphx::shape sh3{migraphx::shape::float_type, {1000}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m1.add_literal(0.4f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 1);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 1);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 1);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 1);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
m2.add_return({dot});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_non_zero_point)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
......@@ -269,18 +508,18 @@ TEST_CASE(dot_add)
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto ab = m2.add_parameter("ab", sh3);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1);
auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab);
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto ab = m2.add_parameter("ab", sh3);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab);
m2.add_return({add});
}
......@@ -320,26 +559,80 @@ TEST_CASE(conv)
auto weights = m2.add_parameter("weights", s4);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q1,
weights);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens());
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale);
m2.add_return({d6});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv_multi_scale)
{
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
migraphx::shape s8{migraphx::shape::float_type, {1280}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s7);
auto weights = m1.add_parameter("weights", s4);
auto w_scale = m1.add_literal(migraphx::generate_literal(s8, 0));
auto inp_scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto d1 = add_quantize_op(m1, "dequantizelinear", weights, w_scale, zero, 0);
auto q1 = add_quantize_op(m1, "quantizelinear", input, inp_scale, zero);
auto d5 = add_quantize_op(m1, "dequantizelinear", q1, inp_scale, zero);
auto c1 = m1.add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q1,
d5,
d1);
m1.add_return({c1});
}
migraphx::module m2;
{
auto input = m2.add_parameter("input", s7);
auto weights = m2.add_parameter("weights", s4);
auto w_scale = m2.add_literal(migraphx::generate_literal(s8, 0));
auto inp_scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q_inp = add_quantize_op(m2, "quantizelinear", input, inp_scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q_inp,
weights);
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
m2.add_return({d6});
auto out_scale = add_scale_mul(m2, inp_scale, w_scale, 1, 1, c1->get_shape().lens());
auto d1 = add_quantize_op(m2, "dequantizelinear", c1, out_scale);
m2.add_return({d1});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv_multi_scale)
TEST_CASE(conv_multi_scale_unsupported_axis)
{
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
......@@ -430,20 +723,20 @@ TEST_CASE(conv_bias_add)
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto zero32 = m2.add_literal(std::int32_t{0});
auto scale1 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q1,
weights);
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto b1 = m2.add_instruction(
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens());
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale);
auto b1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1);
m2.add_return({a1});
......@@ -519,22 +812,21 @@ TEST_CASE(conv_pooling_dot)
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto zero32 = m2.add_literal(std::int32_t{0});
auto scale1 = m2.add_literal(0.25f);
auto scale2 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q1,
weights);
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1);
auto bc1 = m2.add_instruction(
auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens());
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, out_scale1);
auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
auto ap =
......@@ -545,10 +837,11 @@ TEST_CASE(conv_pooling_dot)
{"lengths", {7, 7}},
{"ceil_mode", 0}}),
a1);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2);
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 0, dot->get_shape().lens());
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, out_scale2);
auto mb1 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment