Unverified Commit 0039b11a authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

Support per-axis quantization (#2390)

Reworked the simplify_qdq pass to support:

Per-axis quantization (ie. allow 1D scales and zero points)
Allow broadcast and transpose ops between dq and quant_op
parent b2a40ea6
...@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins) ...@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; }); ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
} }
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
});
});
return all_same;
}
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{ {
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
...@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms) ...@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...); return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
} }
template <class... Ms>
auto skip_broadcasts_transposes_contiguous(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "transpose"))(ms...);
}
template <class T> template <class T>
inline auto has_value(T x, float tolerance = 1e-6) inline auto has_value(T x, float tolerance = 1e-6)
{ {
......
...@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names() ...@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names()
return s; return s;
} }
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins) struct match_find_quantizable_ops
{ {
if(ins->name() != "@literal") static bool
is_valid_scale(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
return scale->get_shape().scalar() or scale->get_shape().elements() == lens.at(axis);
}
static bool is_valid_zero_point(instruction_ref zp)
{
if(not zp->can_eval())
return false; return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) { bool all_zeros = false;
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) { zp->eval().visit([&](auto z) {
return float_equal(scale, s.front()); all_zeros =
}); std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); });
}); });
return all_same; return all_zeros;
} }
struct match_find_quantizable_ops static auto
{ scale_broadcast_op(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
if(scale->get_shape().scalar())
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
}
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop)
{
auto qinp = dqins->inputs().front();
auto next_ins = dqins;
static auto dequantizelinear_op(const std::string& name, const std::string& scale) while(next_ins != qop)
{
if(next_ins->name() != "dequantizelinear")
{
qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp);
}
next_ins = next_ins->outputs().front();
}
return qinp;
}
static auto dequantizelinear_op(const std::string& scale, const std::string& zp)
{ {
return match::name("dequantizelinear")( return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))), match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))), match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0))))); match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp))));
} }
auto matcher() const auto matcher() const
{ {
return match::name(get_quantizable_op_names())( return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")), match::arg(0)(match::skip_broadcasts_transposes_contiguous(
match::arg(1)(dequantizelinear_op("x2", "scale2"))); dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale2", "zp2").bind("dq2"))));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto qop = r.result; auto qop = r.result;
auto q1 = r.instructions["x1"]; auto dq1 = r.instructions["dq1"];
auto q2 = r.instructions["x2"]; auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"]; auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"]; auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
// Only INT8 type currently supported // Only INT8 type currently supported
if(q1->get_shape().type() != migraphx::shape::int8_type or if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or
q2->get_shape().type() != migraphx::shape::int8_type) dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type)
return;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
if(not(is_valid_zero_point(zp1) and is_valid_zero_point(zp2)))
return; return;
double scale; // Only support scalar and 1D scales
visit_all(scale1->get_literal(), scale2->get_literal())( if(scale1->get_shape().lens().size() != 1 or scale2->get_shape().lens().size() != 1)
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); }); return;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs(); auto qop_args = qop->inputs();
qop_args.at(0) = q1; qop_args.at(0) = propagate_quantized_ins(m, dq1, qop);
qop_args.at(1) = q2; qop_args.at(1) = propagate_quantized_ins(m, dq2, qop);
instruction_ref dq; instruction_ref dq;
instruction_ref dq_scale; instruction_ref out_scale;
instruction_ref zero_point; instruction_ref zero_point;
if(qop->name() == "convolution") if(qop->name() == "convolution")
{ {
auto conv_val = qop->get_operator().to_value(); auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction( dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args); qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
auto out_lens = dq->get_shape().lens();
// Input scale should always be scalar and weight scale can be scalar or 1D of the
// same lens as the output channel dim (dim 1 in the output)
if(not(is_valid_scale(scale1, out_lens, 1) and is_valid_scale(scale2, out_lens, 1)))
return;
auto s1_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale1, out_lens, 1), scale1);
auto s2_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale2, out_lens, 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
} }
else if(qop->name() == "dot") else if(qop->name() == "dot")
{ {
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args); dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
auto out_lens = dq->get_shape().lens();
// For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M
// for input1 and K for input 2
if(not(is_valid_scale(scale1, out_lens, out_lens.size() - 2) and
is_valid_scale(scale2, out_lens, out_lens.size() - 1)))
return;
auto s1_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1);
auto s2_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
} }
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens(); dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale);
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
m.replace_instruction(qop, dq); m.replace_instruction(qop, dq);
} }
}; };
......
...@@ -638,11 +638,10 @@ TEST_CASE(dot_float) ...@@ -638,11 +638,10 @@ TEST_CASE(dot_float)
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b); auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f); auto scale_mb = mm->add_instruction(
auto dc = mm->add_literal(100.0f); migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale);
auto mdc = auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), dc); auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, mdc);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args) ...@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args)
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto scale_a = mm->add_literal(10.0); auto scale_a_lit = mm->add_literal(10.0);
auto zp = mm->add_literal(static_cast<int8_t>(0)); auto zp = mm->add_literal(static_cast<int8_t>(0));
scale_a = mm->add_instruction( auto scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a); migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a_lit);
auto zp_a = auto zp_a =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a); auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b = mm->add_literal(5.0); auto scale_b_lit = mm->add_literal(5.0);
scale_b = mm->add_instruction( auto scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b); migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b_lit);
auto zp_b = auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb); auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto scale = mm->add_literal(50.0); auto scale_a_mb = mm->add_instruction(
scale = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale); scale_a_lit);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, scale); auto scale_b_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
scale_b_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_a_mb, scale_b_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -799,18 +802,15 @@ TEST_CASE(dot_half_1arg) ...@@ -799,18 +802,15 @@ TEST_CASE(dot_half_1arg)
auto x = mm->add_parameter("x", sa); auto x = mm->add_parameter("x", sa);
auto zp = mm->add_literal(static_cast<int8_t>(0)); auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sa.type()}, {10.0})); auto scale_lit = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), auto scale = mm->add_instruction(
scale); migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_lit);
zp = zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp); auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx); auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0})); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale, scale);
dq_scale = mm->add_instruction( auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
dq_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, dq_scale);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -852,9 +852,9 @@ TEST_CASE(conv_float) ...@@ -852,9 +852,9 @@ TEST_CASE(conv_float)
auto pw = mm->add_parameter("w", sw); auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0)); auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f); auto scale_lit = mm->add_literal(10.0f);
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), auto scale = mm->add_instruction(
scale); migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp = zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp); auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
...@@ -862,13 +862,11 @@ TEST_CASE(conv_float) ...@@ -862,13 +862,11 @@ TEST_CASE(conv_float)
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w); auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
migraphx::shape sc{migraphx::shape::float_type, {4, 4, 1, 1}}; auto scale_mb = mm->add_instruction(
std::vector<float> vec(sc.elements(), 100.0f); migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
migraphx::shape s_scale{migraphx::shape::float_type, sc.lens()}; scale_lit);
auto d_scale = mm->add_literal(100.0f); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
d_scale = mm->add_instruction( auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -931,19 +929,20 @@ TEST_CASE(conv_half) ...@@ -931,19 +929,20 @@ TEST_CASE(conv_half)
auto pw = mm->add_parameter("w", sw); auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0)); auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sx.type()}, {10.0})); auto scale_lit = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), auto scale = mm->add_instruction(
scale); migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp = zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp); auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp); auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp);
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w); auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
auto d_scale = mm->add_literal(migraphx::literal({sx.type()}, {100.0})); auto scale_mb = mm->add_instruction(
d_scale = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale); scale_lit);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -1187,9 +1186,9 @@ TEST_CASE(int8_subgraph) ...@@ -1187,9 +1186,9 @@ TEST_CASE(int8_subgraph)
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1); migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb); auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb); auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto so = then_mod->add_literal(100.0f); auto s1_mb = then_mod->add_instruction(
so = then_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), s1);
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so); auto so = then_mod->add_instruction(migraphx::make_op("mul"), s1_mb, s1_mb);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so); auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r}); then_mod->add_return({r});
...@@ -1199,23 +1198,24 @@ TEST_CASE(int8_subgraph) ...@@ -1199,23 +1198,24 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw); auto w = mm->add_parameter("w", sw);
// else submod // else submod
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
auto sax = else_mod->add_literal(2.0f); auto sax_lit = else_mod->add_literal(2.0f);
auto zp = else_mod->add_literal(static_cast<int8_t>(0)); auto zp = else_mod->add_literal(static_cast<int8_t>(0));
sax = else_mod->add_instruction( auto sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax); migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax_lit);
auto zpx = else_mod->add_instruction( auto zpx = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp); migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp);
auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx); auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx);
auto ssw = else_mod->add_literal(1.66667f); auto ssw_lit = else_mod->add_literal(1.66667f);
ssw = else_mod->add_instruction( auto ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw); migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw_lit);
auto zpw = else_mod->add_instruction( auto zpw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp); migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp);
auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw); auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw);
auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw); auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto so1 = else_mod->add_literal(3.33333f); auto ssw_mb = else_mod->add_instruction(
so1 = else_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}),
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so1); ssw_lit);
auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1); auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
else_mod->add_return({r1}); else_mod->add_return({r1});
......
...@@ -44,20 +44,34 @@ void run_pass(migraphx::module& m) ...@@ -44,20 +44,34 @@ void run_pass(migraphx::module& m)
sqdq.apply(m); sqdq.apply(m);
} }
migraphx::instruction_ref add_quantize_op(migraphx::module& m, migraphx::instruction_ref broadcast_scale(migraphx::module& m,
const std::string& name,
migraphx::instruction_ref x,
migraphx::instruction_ref scale, migraphx::instruction_ref scale,
migraphx::instruction_ref shift) const std::vector<std::size_t>& out_lens,
std::size_t axis)
{ {
auto lens = x->get_shape().lens(); if(scale->get_shape().lens() == out_lens)
return scale;
migraphx::instruction_ref scale_mb; migraphx::instruction_ref scale_mb;
if(scale->get_shape().lens().front() == 1) auto scale_lens = scale->get_shape().lens();
if(scale_lens.front() == 1 and scale_lens.size() == 1)
scale_mb = scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), scale);
else else
scale_mb = m.add_instruction( scale_mb = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale); migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", out_lens}}), scale);
return scale_mb;
}
migraphx::instruction_ref add_quantize_op(migraphx::module& m,
const std::string& name,
migraphx::instruction_ref x,
migraphx::instruction_ref scale,
migraphx::instruction_ref shift,
std::size_t q_axis = 1)
{
auto lens = x->get_shape().lens();
auto scale_mb = broadcast_scale(m, scale, lens, q_axis);
auto shift_mb = auto shift_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), shift); m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), shift);
return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb); return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb);
...@@ -66,19 +80,26 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m, ...@@ -66,19 +80,26 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m,
migraphx::instruction_ref add_quantize_op(migraphx::module& m, migraphx::instruction_ref add_quantize_op(migraphx::module& m,
const std::string& name, const std::string& name,
migraphx::instruction_ref x, migraphx::instruction_ref x,
migraphx::instruction_ref scale) migraphx::instruction_ref scale,
std::size_t q_axis = 1)
{ {
auto lens = x->get_shape().lens(); auto lens = x->get_shape().lens();
migraphx::instruction_ref scale_mb; auto scale_mb = broadcast_scale(m, scale, lens, q_axis);
if(scale->get_shape().lens().front() == 1)
scale_mb =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale);
else
scale_mb = m.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale);
return m.add_instruction(migraphx::make_op(name), x, scale_mb); return m.add_instruction(migraphx::make_op(name), x, scale_mb);
} }
migraphx::instruction_ref add_scale_mul(migraphx::module& m,
migraphx::instruction_ref scale1,
migraphx::instruction_ref scale2,
std::size_t axis1,
std::size_t axis2,
const std::vector<std::size_t>& out_lens)
{
auto scale1_mb = broadcast_scale(m, scale1, out_lens, axis1);
auto scale2_mb = broadcast_scale(m, scale2, out_lens, axis2);
return m.add_instruction(migraphx::make_op("mul"), scale1_mb, scale2_mb);
}
TEST_CASE(remove_qdq) TEST_CASE(remove_qdq)
{ {
migraphx::shape sh1{migraphx::shape::float_type, {100, 100}}; migraphx::shape sh1{migraphx::shape::float_type, {100, 100}};
...@@ -165,12 +186,144 @@ TEST_CASE(dot) ...@@ -165,12 +186,144 @@ TEST_CASE(dot)
auto t2 = m2.add_parameter("t2", sh2); auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f); auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_multi_scale)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
migraphx::shape sh3{migraphx::shape::float_type, {1280}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m1.add_literal(0.4f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 0);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 0);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 1);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 1);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale1 = m2.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m2.add_literal(0.4f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 0);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 1);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto out_scale = add_scale_mul(m2, scale1, scale2, 0, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_broadcasted)
{
migraphx::shape sh1{migraphx::shape::float_type, {2, 1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto d2_mb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 1000, 1024}}}), d2);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto q2_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 1000, 1024}}}), q2);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_transposed)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto d2_t =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_t);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto q2_t =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_t);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3}); m2.add_return({d3});
} }
...@@ -178,6 +331,92 @@ TEST_CASE(dot) ...@@ -178,6 +331,92 @@ TEST_CASE(dot)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(dot_multi_scale_transposed_broadcasted)
{
migraphx::shape sh1{migraphx::shape::float_type, {2, 3, 1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}};
migraphx::shape sh3{migraphx::shape::float_type, {1280}};
migraphx::shape sh4{migraphx::shape::float_type, {1024}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m1.add_literal(migraphx::generate_literal(sh4, 0));
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 2);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 2);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 0);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 0);
auto d2_t =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2);
auto d2_mb = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), d2_t);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale1 = m2.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m2.add_literal(migraphx::generate_literal(sh4, 0));
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 2);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 0);
auto q2_t =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2);
auto q2_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), q2_t);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb);
auto out_scale = add_scale_mul(m2, scale1, scale2, 2, 3, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
m2.add_return({d3});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_multi_scale_unsupported_axis)
{
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}};
migraphx::shape sh3{migraphx::shape::float_type, {1000}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0));
auto scale2 = m1.add_literal(0.4f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 1);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 1);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 1);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 1);
auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2);
m1.add_return({dot});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2);
m2.add_return({dot});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(dot_non_zero_point) TEST_CASE(dot_non_zero_point)
{ {
migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}}; migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}};
...@@ -274,12 +513,12 @@ TEST_CASE(dot_add) ...@@ -274,12 +513,12 @@ TEST_CASE(dot_add)
auto ab = m2.add_parameter("ab", sh3); auto ab = m2.add_parameter("ab", sh3);
auto scale = m2.add_literal(0.5f); auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2);
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab); auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab);
m2.add_return({add}); m2.add_return({add});
} }
...@@ -320,7 +559,6 @@ TEST_CASE(conv) ...@@ -320,7 +559,6 @@ TEST_CASE(conv)
auto weights = m2.add_parameter("weights", s4); auto weights = m2.add_parameter("weights", s4);
auto scale = m2.add_literal(0.5f); auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto scale1 = m2.add_literal(0.25f);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
...@@ -331,7 +569,8 @@ TEST_CASE(conv) ...@@ -331,7 +569,8 @@ TEST_CASE(conv)
{"padding_mode", 0}}), {"padding_mode", 0}}),
q1, q1,
weights); weights);
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens());
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale);
m2.add_return({d6}); m2.add_return({d6});
} }
...@@ -340,6 +579,60 @@ TEST_CASE(conv) ...@@ -340,6 +579,60 @@ TEST_CASE(conv)
} }
TEST_CASE(conv_multi_scale) TEST_CASE(conv_multi_scale)
{
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
migraphx::shape s8{migraphx::shape::float_type, {1280}};
migraphx::module m1;
{
auto input = m1.add_parameter("input", s7);
auto weights = m1.add_parameter("weights", s4);
auto w_scale = m1.add_literal(migraphx::generate_literal(s8, 0));
auto inp_scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto d1 = add_quantize_op(m1, "dequantizelinear", weights, w_scale, zero, 0);
auto q1 = add_quantize_op(m1, "quantizelinear", input, inp_scale, zero);
auto d5 = add_quantize_op(m1, "dequantizelinear", q1, inp_scale, zero);
auto c1 = m1.add_instruction(migraphx::make_op("convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
d5,
d1);
m1.add_return({c1});
}
migraphx::module m2;
{
auto input = m2.add_parameter("input", s7);
auto weights = m2.add_parameter("weights", s4);
auto w_scale = m2.add_literal(migraphx::generate_literal(s8, 0));
auto inp_scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q_inp = add_quantize_op(m2, "quantizelinear", input, inp_scale, zero);
auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution",
{{"padding", {0, 0, 0, 0}},
{"stride", {1, 1}},
{"dilation", {1, 1}},
{"group", 1},
{"padding_mode", 0}}),
q_inp,
weights);
auto out_scale = add_scale_mul(m2, inp_scale, w_scale, 1, 1, c1->get_shape().lens());
auto d1 = add_quantize_op(m2, "dequantizelinear", c1, out_scale);
m2.add_return({d1});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv_multi_scale_unsupported_axis)
{ {
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}}; migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}}; migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
...@@ -430,7 +723,6 @@ TEST_CASE(conv_bias_add) ...@@ -430,7 +723,6 @@ TEST_CASE(conv_bias_add)
auto scale = m2.add_literal(0.5f); auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto zero32 = m2.add_literal(std::int32_t{0}); auto zero32 = m2.add_literal(std::int32_t{0});
auto scale1 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32); auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero);
...@@ -442,7 +734,8 @@ TEST_CASE(conv_bias_add) ...@@ -442,7 +734,8 @@ TEST_CASE(conv_bias_add)
{"padding_mode", 0}}), {"padding_mode", 0}}),
q1, q1,
weights); weights);
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens());
auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale);
auto b1 = m2.add_instruction( auto b1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1); auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1);
...@@ -519,8 +812,6 @@ TEST_CASE(conv_pooling_dot) ...@@ -519,8 +812,6 @@ TEST_CASE(conv_pooling_dot)
auto scale = m2.add_literal(0.5f); auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0}); auto zero = m2.add_literal(std::int8_t{0});
auto zero32 = m2.add_literal(std::int32_t{0}); auto zero32 = m2.add_literal(std::int32_t{0});
auto scale1 = m2.add_literal(0.25f);
auto scale2 = m2.add_literal(0.25f);
auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32); auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32);
auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero); auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero);
...@@ -533,7 +824,8 @@ TEST_CASE(conv_pooling_dot) ...@@ -533,7 +824,8 @@ TEST_CASE(conv_pooling_dot)
{"padding_mode", 0}}), {"padding_mode", 0}}),
q1, q1,
weights); weights);
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1); auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens());
auto d5 = add_quantize_op(m2, "dequantizelinear", c1, out_scale1);
auto bc1 = m2.add_instruction( auto bc1 = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2);
auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1); auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1);
...@@ -548,7 +840,8 @@ TEST_CASE(conv_pooling_dot) ...@@ -548,7 +840,8 @@ TEST_CASE(conv_pooling_dot)
auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap);
auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero); auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero);
auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db); auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db);
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2); auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 0, dot->get_shape().lens());
auto d9 = add_quantize_op(m2, "dequantizelinear", dot, out_scale2);
auto mb1 = auto mb1 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3);
auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1); auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment