Unverified Commit 0039b11a authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

Support per-axis quantization (#2390)

Reworked the simplify_qdq pass to support:

Per-axis quantization (ie. allow 1D scales and zero points)
Allow broadcast and transpose ops between dq and quant_op
parent b2a40ea6
...@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins) ...@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; }); ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
} }
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
});
});
return all_same;
}
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{ {
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
...@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms) ...@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms)
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...); return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
} }
template <class... Ms>
auto skip_broadcasts_transposes_contiguous(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "transpose"))(ms...);
}
template <class T> template <class T>
inline auto has_value(T x, float tolerance = 1e-6) inline auto has_value(T x, float tolerance = 1e-6)
{ {
......
...@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names() ...@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names()
return s; return s;
} }
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins) struct match_find_quantizable_ops
{ {
if(ins->name() != "@literal") static bool
return false; is_valid_scale(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
bool all_same = false; {
ins->get_literal().visit([&](auto s) { return scale->get_shape().scalar() or scale->get_shape().elements() == lens.at(axis);
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) { }
return float_equal(scale, s.front());
static bool is_valid_zero_point(instruction_ref zp)
{
if(not zp->can_eval())
return false;
bool all_zeros = false;
zp->eval().visit([&](auto z) {
all_zeros =
std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); });
}); });
}); return all_zeros;
return all_same; }
}
struct match_find_quantizable_ops static auto
{ scale_broadcast_op(instruction_ref scale, std::vector<std::size_t> lens, std::size_t axis)
{
if(scale->get_shape().scalar())
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
}
static auto dequantizelinear_op(const std::string& name, const std::string& scale) // Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop)
{
auto qinp = dqins->inputs().front();
auto next_ins = dqins;
while(next_ins != qop)
{
if(next_ins->name() != "dequantizelinear")
{
qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp);
}
next_ins = next_ins->outputs().front();
}
return qinp;
}
static auto dequantizelinear_op(const std::string& scale, const std::string& zp)
{ {
return match::name("dequantizelinear")( return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))), match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))), match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0))))); match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp))));
} }
auto matcher() const auto matcher() const
{ {
return match::name(get_quantizable_op_names())( return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")), match::arg(0)(match::skip_broadcasts_transposes_contiguous(
match::arg(1)(dequantizelinear_op("x2", "scale2"))); dequantizelinear_op("scale1", "zp1").bind("dq1"))),
match::arg(1)(match::skip_broadcasts_transposes_contiguous(
dequantizelinear_op("scale2", "zp2").bind("dq2"))));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto qop = r.result; auto qop = r.result;
auto q1 = r.instructions["x1"]; auto dq1 = r.instructions["dq1"];
auto q2 = r.instructions["x2"]; auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"]; auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"]; auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];
// Only INT8 type currently supported // Only INT8 type currently supported
if(q1->get_shape().type() != migraphx::shape::int8_type or if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or
q2->get_shape().type() != migraphx::shape::int8_type) dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type)
return; return;
double scale; // Only symmetric quantization supported (ie. non-zero zero_points not allowed)
visit_all(scale1->get_literal(), scale2->get_literal())( if(not(is_valid_zero_point(zp1) and is_valid_zero_point(zp2)))
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); }); return;
// Only support scalar and 1D scales
if(scale1->get_shape().lens().size() != 1 or scale2->get_shape().lens().size() != 1)
return;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs(); auto qop_args = qop->inputs();
qop_args.at(0) = q1; qop_args.at(0) = propagate_quantized_ins(m, dq1, qop);
qop_args.at(1) = q2; qop_args.at(1) = propagate_quantized_ins(m, dq2, qop);
instruction_ref dq; instruction_ref dq;
instruction_ref dq_scale; instruction_ref out_scale;
instruction_ref zero_point; instruction_ref zero_point;
if(qop->name() == "convolution") if(qop->name() == "convolution")
{ {
auto conv_val = qop->get_operator().to_value(); auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction( dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args); qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
auto out_lens = dq->get_shape().lens();
// Input scale should always be scalar and weight scale can be scalar or 1D of the
// same lens as the output channel dim (dim 1 in the output)
if(not(is_valid_scale(scale1, out_lens, 1) and is_valid_scale(scale2, out_lens, 1)))
return;
auto s1_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale1, out_lens, 1), scale1);
auto s2_bcast =
m.insert_instruction(qop, scale_broadcast_op(scale2, out_lens, 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
} }
else if(qop->name() == "dot") else if(qop->name() == "dot")
{ {
dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args); dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args);
auto out_lens = dq->get_shape().lens();
// For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M
// for input1 and K for input 2
if(not(is_valid_scale(scale1, out_lens, out_lens.size() - 2) and
is_valid_scale(scale2, out_lens, out_lens.size() - 1)))
return;
auto s1_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1);
auto s2_bcast = m.insert_instruction(
qop, scale_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2);
out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast);
} }
auto ins_type = qop->get_shape().type();
dq_scale = m.add_literal(literal({ins_type}, {scale}));
auto lens = dq->get_shape().lens(); dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale);
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
m.replace_instruction(qop, dq); m.replace_instruction(qop, dq);
} }
}; };
......
...@@ -636,13 +636,12 @@ TEST_CASE(dot_float) ...@@ -636,13 +636,12 @@ TEST_CASE(dot_float)
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale); migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale);
auto zp_b = auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b); auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b);
std::vector<float> vec(sc.elements(), 100.0f); auto scale_mb = mm->add_instruction(
auto dc = mm->add_literal(100.0f); migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale);
auto mdc = auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), dc); auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, mdc);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args) ...@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args)
auto pa = mm->add_parameter("a", sa); auto pa = mm->add_parameter("a", sa);
auto pb = mm->add_parameter("b", sb); auto pb = mm->add_parameter("b", sb);
auto scale_a = mm->add_literal(10.0); auto scale_a_lit = mm->add_literal(10.0);
auto zp = mm->add_literal(static_cast<int8_t>(0)); auto zp = mm->add_literal(static_cast<int8_t>(0));
scale_a = mm->add_instruction( auto scale_a = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a); migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a_lit);
auto zp_a = auto zp_a =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a); auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a);
auto scale_b = mm->add_literal(5.0); auto scale_b_lit = mm->add_literal(5.0);
scale_b = mm->add_instruction( auto scale_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b); migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b_lit);
auto zp_b = auto zp_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp);
auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb); auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto scale = mm->add_literal(50.0); auto scale_a_mb = mm->add_instruction(
scale = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale); scale_a_lit);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, scale); auto scale_b_mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
scale_b_lit);
auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_a_mb, scale_b_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -798,19 +801,16 @@ TEST_CASE(dot_half_1arg) ...@@ -798,19 +801,16 @@ TEST_CASE(dot_half_1arg)
migraphx::shape sa{migraphx::shape::half_type, {9, 9}}; migraphx::shape sa{migraphx::shape::half_type, {9, 9}};
auto x = mm->add_parameter("x", sa); auto x = mm->add_parameter("x", sa);
auto zp = mm->add_literal(static_cast<int8_t>(0)); auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sa.type()}, {10.0})); auto scale_lit = mm->add_literal(migraphx::literal({sa.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), auto scale = mm->add_instruction(
scale); migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_lit);
zp = zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp);
auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp); auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp);
auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx); auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx);
auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0})); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale, scale);
dq_scale = mm->add_instruction( auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale);
migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}),
dq_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, dq_scale);
mm->add_return({r}); mm->add_return({r});
return p; return p;
}; };
...@@ -851,10 +851,10 @@ TEST_CASE(conv_float) ...@@ -851,10 +851,10 @@ TEST_CASE(conv_float)
auto px = mm->add_parameter("x", sx); auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw); auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0)); auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(10.0f); auto scale_lit = mm->add_literal(10.0f);
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), auto scale = mm->add_instruction(
scale); migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp = zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp); auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
...@@ -862,13 +862,11 @@ TEST_CASE(conv_float) ...@@ -862,13 +862,11 @@ TEST_CASE(conv_float)
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w); auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
migraphx::shape sc{migraphx::shape::float_type, {4, 4, 1, 1}}; auto scale_mb = mm->add_instruction(
std::vector<float> vec(sc.elements(), 100.0f); migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
migraphx::shape s_scale{migraphx::shape::float_type, sc.lens()}; scale_lit);
auto d_scale = mm->add_literal(100.0f); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
d_scale = mm->add_instruction( auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -930,20 +928,21 @@ TEST_CASE(conv_half) ...@@ -930,20 +928,21 @@ TEST_CASE(conv_half)
auto px = mm->add_parameter("x", sx); auto px = mm->add_parameter("x", sx);
auto pw = mm->add_parameter("w", sw); auto pw = mm->add_parameter("w", sw);
auto zp = mm->add_literal(static_cast<int8_t>(0)); auto zp = mm->add_literal(static_cast<int8_t>(0));
auto scale = mm->add_literal(migraphx::literal({sx.type()}, {10.0})); auto scale_lit = mm->add_literal(migraphx::literal({sx.type()}, {10.0}));
scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), auto scale = mm->add_instruction(
scale); migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit);
zp = zp =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp); mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp);
auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp); auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp);
auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp); auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp);
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w); auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w);
auto d_scale = mm->add_literal(migraphx::literal({sx.type()}, {100.0})); auto scale_mb = mm->add_instruction(
d_scale = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}),
migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale); scale_lit);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale); auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb);
auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale);
mm->add_return({r}); mm->add_return({r});
return p; return p;
...@@ -1185,12 +1184,12 @@ TEST_CASE(int8_subgraph) ...@@ -1185,12 +1184,12 @@ TEST_CASE(int8_subgraph)
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1); migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1);
auto zpb = then_mod->add_instruction( auto zpb = then_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1); migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1);
auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb); auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb);
auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb); auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb);
auto so = then_mod->add_literal(100.0f); auto s1_mb = then_mod->add_instruction(
so = then_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), s1);
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so); auto so = then_mod->add_instruction(migraphx::make_op("mul"), s1_mb, s1_mb);
auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so); auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so);
then_mod->add_return({r}); then_mod->add_return({r});
migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}}; migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}};
...@@ -1199,24 +1198,25 @@ TEST_CASE(int8_subgraph) ...@@ -1199,24 +1198,25 @@ TEST_CASE(int8_subgraph)
auto w = mm->add_parameter("w", sw); auto w = mm->add_parameter("w", sw);
// else submod // else submod
auto* else_mod = p.create_module("If_6_else"); auto* else_mod = p.create_module("If_6_else");
auto sax = else_mod->add_literal(2.0f); auto sax_lit = else_mod->add_literal(2.0f);
auto zp = else_mod->add_literal(static_cast<int8_t>(0)); auto zp = else_mod->add_literal(static_cast<int8_t>(0));
sax = else_mod->add_instruction( auto sax = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax); migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax_lit);
auto zpx = else_mod->add_instruction( auto zpx = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp); migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp);
auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx); auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx);
auto ssw = else_mod->add_literal(1.66667f); auto ssw_lit = else_mod->add_literal(1.66667f);
ssw = else_mod->add_instruction( auto ssw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw); migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw_lit);
auto zpw = else_mod->add_instruction( auto zpw = else_mod->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp); migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp);
auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw); auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw);
auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw); auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw);
auto so1 = else_mod->add_literal(3.33333f); auto ssw_mb = else_mod->add_instruction(
so1 = else_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}),
migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so1); ssw_lit);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1); auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb);
auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1);
else_mod->add_return({r1}); else_mod->add_return({r1});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment