Unverified commit 0b5f33b6, authored by turneram, committed by GitHub

Optimize Q/DQ Format Pass (#889)

* Add operators, refactor parsers, add rewrite passes, add tests

* Add ref implementations

* Move broadcasting of scales and zero points to onnx parser

* Allow for x and zero_point to have different types in quantizelinear; fix zero_point default type

* Switch certain variables to int64_t

* Fix overflow in implicit constant conversion

* Remove operators.hpp from includes in tf_test.cpp

* Add conversion for int32 input to quantizelinear and add test case; remove operators.hpp from onnx_test.cpp includes

* Switch dequantizelinear math from int32 to float

* Remove changes to operators.hpp

* Simplify apply_quantizelinear

* Add verify test for int32 data

* Add rewrite_quantization back to CMakeLists

* Add passes to insert qdq after add_bias is applied, replace quant_ops, and remove remaining qdq pairs

* Renaming, refactoring, cleaning up code, adding a formal test, and adding passes to targets

* Renaming, review comments, begin adding more specific tests

* Add more specific unit tests

* Fix failing test on CI

* Correct matcher and update qop rewriting, update tests and add more tests

* Update matcher, clean up simplify_qdq, tweak tests

* Add tests, remove pass from CPU target, update dot parameters, clean up simplify_qdq

* Fix correctness bug in ref q/dq implementations; edit gemm parser to make beta always 0.0

* Remove unused variables in onnx gemm tests
parent 4e3b2e3c
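
At a high level, this PR adds a simplify_qdq pass that rewrites fake quantization of the form dq -> quantizable_op -> q into true int8 operators. A minimal sketch of the intended rewrite, with illustrative instruction names (per-tensor scales and zero points of 0, which is the pattern the pass matches):

    // before: float convolution wrapped in fake quantization
    d1 = dequantizelinear(q1, s1)    // q1 = quantizelinear(x1, s1)
    d2 = dequantizelinear(q2, s2)    // q2 = quantizelinear(w, s2)
    y  = convolution(d1, d2)
    // after: one int8 convolution, dequantized once with the product of scales
    y  = dequantizelinear(quant_convolution(q1, q2), s1 * s2)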
@@ -51,6 +51,7 @@ add_library(migraphx
register_op.cpp
register_target.cpp
remap.cpp
simplify_qdq.cpp
rewrite_batchnorm.cpp
rewrite_pooling.cpp
rewrite_quantization.cpp
......
@@ -263,6 +263,20 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
return result;
}
/// Find first instance of a matching instruction in a module
template <class M>
match::matcher_result find_match(module& modl, M&& m)
{
match::matcher_result result;
for(auto ins : iterator_for(modl))
{
result = match::match_instruction(modl, ins, m);
if(result.result != modl.end())
return result;
}
return result;
}
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the module
......
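For orientation, a sketch of how the new find_match helper might be called; the matcher and module below are placeholders, and the namespace qualification should be adjusted to wherever find_match is actually declared:

    // Hypothetical usage: locate the first "add" instruction in module m.
    auto r = find_match(m, migraphx::match::name("add"));
    if(r.result != m.end())
    {
        // r.result is the matched instruction; r.instructions holds any
        // sub-matches bound with .bind("name").
    }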
@@ -25,6 +25,7 @@ struct dequantizelinear
std::string name() const { return "dequantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_dims();
return {shape::float_type, inputs[0].lens(), inputs[0].strides()};
}
@@ -32,7 +33,7 @@ struct dequantizelinear
{
auto x = args.at(0);
auto x_scale = args.at(1);
std::vector<int8_t> zeros(output_shape.elements(), 0);
std::vector<int8_t> zeros(output_shape.bytes(), 0);
argument x_zero_point{{x.get_shape().type(), output_shape.lens()}, zeros.data()};
if(args.size() == 3)
{
......
@@ -25,6 +25,7 @@ struct quantizelinear
std::string name() const { return "quantizelinear"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_dims();
if(inputs.size() == 3)
{
return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()};
@@ -36,7 +37,7 @@ struct quantizelinear
{
auto x = args.at(0);
auto y_scale = args.at(1);
std::vector<int8_t> zeros(output_shape.elements(), 0);
std::vector<int8_t> zeros(output_shape.bytes(), 0);
argument y_zero_point{output_shape, zeros.data()};
if(args.size() == 3)
{
......
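In both reference implementations above, the temporary zero buffer is now sized by output_shape.bytes() rather than elements(). Since the implicit zero point can be a wider type than int8_t (this PR adds int32 input support), an element-count buffer of int8_t could be smaller than the zero-point tensor backed by it; a byte-count buffer of the output shape is large enough in the cases handled here.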
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_QDQ_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_QDQ_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Inserts quantized operators in place of dq->quantizable_op->q,
* then removes remaining fake quantization (q->dq pairs)
*/
struct simplify_qdq
{
std::string name() const { return "simplify_qdq"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
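
A minimal usage sketch for the new pass, mirroring how apply() itself chains dead_code_elimination (assume m is a populated migraphx::module):

    #include <migraphx/simplify_qdq.hpp>
    #include <migraphx/dead_code_elimination.hpp>
    #include <migraphx/pass_manager.hpp>

    // Replace dq->op->q patterns with int8 ops, then drop the dead fake-quant chain.
    migraphx::run_passes(m, {migraphx::simplify_qdq{}, migraphx::dead_code_elimination{}});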
@@ -59,9 +59,11 @@ struct parse_gemm : op_parser<parse_gemm>
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1];
auto ret = info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l1, l2);
if(args.size() == 3)
{
if(beta != 0.0f && args[2]->get_shape().elements() > 0)
if(not float_equal(beta, 0.0f) && args[2]->get_shape().elements() > 0)
{
auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back();
@@ -80,12 +82,11 @@ struct parse_gemm : op_parser<parse_gemm>
beta_l3);
}
return info.add_instruction(
make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2, beta_l3);
return info.add_instruction(make_op("add"), ret, beta_l3);
}
}
return info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2);
return ret;
}
};
......
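Per the commit notes, the gemm parser now always emits dot with beta = 0.0 and expresses the C term as an explicit multibroadcast/mul/add chain. Keeping dot free of an implicit bias matters for the new pass: simplify_qdq only rewrites a dot with alpha = 1 and beta = 0 into quant_dot, as seen below.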
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::unordered_set<std::string> get_quantizable_op_names()
{
static std::unordered_set<std::string> s = {"convolution", "dot"};
return s;
}
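// Predicate matcher: true for a literal whose elements are all equal,
// i.e. a per-tensor scale that may have been multibroadcast.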
MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins)
{
if(ins->name() != "@literal")
return false;
bool all_same = false;
ins->get_literal().visit([&](auto s) {
all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) {
return float_equal(scale, s.front());
});
});
return all_same;
}
struct match_find_quantizable_ops
{
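// Matches a dequantizelinear whose data input (skipping an optional
// quantizelinear) is bound to `name`, whose scale is a single repeated
// value bound to `scale`, and whose zero point is all zeros.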
static auto dequantizelinear_op(const std::string& name, const std::string& scale)
{
return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))),
match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0)))));
}
auto matcher() const
{
return match::name(get_quantizable_op_names())(
match::arg(0)(dequantizelinear_op("x1", "scale1")),
match::arg(1)(dequantizelinear_op("x2", "scale2")));
}
void apply(module& m, match::matcher_result r) const
{
auto qop = r.result;
auto q1 = r.instructions["x1"];
auto q2 = r.instructions["x2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
// Only INT8 type currently supported
if(q1->get_shape().type() != migraphx::shape::int8_type or
q2->get_shape().type() != migraphx::shape::int8_type)
return;
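// With zero points of 0 (enforced by the matcher), dequantizelinear computes
// y = scale * q, so op(s1*q1, s2*q2) == (s1*s2) * op(q1, q2) for the linear
// ops matched here; the combined scale then dequantizes the int32 result.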
double scale;
visit_all(scale1->get_literal(), scale2->get_literal())(
[&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); });
auto qop_args = qop->inputs();
qop_args.at(0) = q1;
qop_args.at(1) = q2;
instruction_ref dq;
if(qop->name() == "convolution")
{
auto conv_val = qop->get_operator().to_value();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_convolution", conv_val), qop_args);
}
else if(qop->name() == "dot")
{
auto dot_op = any_cast<op::dot>(qop->get_operator());
if(not(float_equal(dot_op.alpha, 1.0f) and float_equal(dot_op.beta, 0.0f)))
return;
if(qop_args.size() == 3)
qop_args.pop_back();
dq = m.insert_instruction(
qop, migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qop_args);
}
auto dq_scale = m.add_literal(static_cast<float>(scale));
auto lens = dq->get_shape().lens();
auto scale_mb =
m.insert_instruction(qop, make_op("multibroadcast", {{"output_lens", lens}}), dq_scale);
dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb);
m.replace_instruction(qop, dq);
}
};
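// Returns true if the two inputs (looking through broadcasts) evaluate to
// equal constants, or to constants whose elements all share a single value.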
bool compare_literals(instruction_ref ins1, instruction_ref ins2)
{
if(ins1->name() == "broadcast" or ins1->name() == "multibroadcast")
ins1 = ins1->inputs().front();
auto x = ins1->eval();
if(x.empty())
return false;
if(ins2->name() == "broadcast" or ins2->name() == "multibroadcast")
ins2 = ins2->inputs().front();
auto y = ins2->eval();
if(y.empty())
return false;
bool diff_shapes_equal_vals = false;
visit_all(ins1->get_literal(), ins2->get_literal())([&](const auto l1, const auto l2) {
diff_shapes_equal_vals =
std::all_of(
l1.begin() + 1, l1.end(), [&](auto v) { return float_equal(v, l1.front()); }) and
std::all_of(l2.begin(), l2.end(), [&](auto v) { return float_equal(v, l1.front()); });
});
return (x == y) or diff_shapes_equal_vals;
}
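// Bypass q->dq pairs whose quantization parameters match: with identical scale
// and zero point, dequantize(quantize(x)) reproduces x up to rounding and
// saturation, so users of the dq output can read the original value directly.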
void remove_qdq_pairs(module& m)
{
for(auto ins : iterator_for(m))
{
auto args = ins->inputs();
for(auto&& arg : args)
{
if(arg->name() == "dequantizelinear")
{
auto q = arg->inputs().front();
if((q->name() == "quantizelinear") and
compare_literals(arg->inputs().at(1), q->inputs().at(1)) and
compare_literals(arg->inputs().at(2), q->inputs().at(2)))
{
instruction::replace_argument(ins, arg, q->inputs().front());
}
}
}
}
}
void simplify_qdq::apply(module& m) const
{
match::find_matches(m, match_find_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_qdq_pairs(m);
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -22,6 +22,7 @@
#include <migraphx/schedule.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/preallocate_param.hpp>
#include <migraphx/cpu/fuse_ops.hpp>
......
@@ -24,6 +24,7 @@
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
@@ -60,6 +61,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
normalize_ops{},
decompose{},
dead_code_elimination{},
simplify_qdq{},
rewrite_quantization{},
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type},
......
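Note the ordering in the GPU pass list above: simplify_qdq runs before rewrite_quantization, so dq -> op -> q patterns are converted into true int8 operators first, and only the quantizelinear/dequantizelinear instructions that remain get lowered by rewrite_quantization.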
@@ -7,19 +7,6 @@ namespace match = migraphx::match;
MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); }
template <class M>
migraphx::match::matcher_result find_match(migraphx::module& modl, M&& m)
{
migraphx::match::matcher_result result;
for(auto ins : migraphx::iterator_for(modl))
{
result = migraphx::match::match_instruction(modl, ins, m);
if(result.result != modl.end())
return result;
}
return result;
}
void match1()
{
migraphx::module mm;
......
@@ -1295,14 +1295,16 @@ TEST_CASE(gemm_test)
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), t_a);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {1, 0}}}), l1);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, t1);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", {7, 11}}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, t1, l2_bb);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_test.onnx");
EXPECT(p == prog);
}
@@ -1320,14 +1322,15 @@ TEST_CASE(gemm_ex_test)
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2->get_shape().lens()}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b);
auto prog = optimize_onnx("gemm_ex_test.onnx");
EXPECT(p == prog);
}
@@ -1345,17 +1348,17 @@ TEST_CASE(gemm_ex_brcst_test)
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
auto b_l = mm->add_literal(beta);
auto l2_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", out_lens}}), l2);
auto b_b = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"output_lens", l2_b->get_shape().lens()}}), b_l);
auto l2_bb = mm->add_instruction(migraphx::make_op("mul"), l2_b, b_b);
mm->add_instruction(
migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_bb);
mm->add_instruction(migraphx::make_op("add"), dot, l2_bb);
auto prog = optimize_onnx("gemm_ex_brcst_test.onnx");
EXPECT(p == prog);
}
@@ -1374,6 +1377,8 @@ TEST_CASE(gemm_half_test)
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {0, 1, 3, 2}}}), t_a);
std::vector<std::size_t> lens = {1, 1, 6, 7};
auto dot =
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), t_a, l1);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"output_lens", lens}}), l2);
l2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2);
@@ -1383,10 +1388,9 @@ TEST_CASE(gemm_half_test)
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
l2_b = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l2_b);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), t_a, l1, l2_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b);
auto prog = optimize_onnx("gemm_half_test.onnx");
EXPECT(p == prog);
}
@@ -1806,7 +1810,7 @@ TEST_CASE(initializer_not_an_input)
std::vector<float> w = {1, 2, 3, 4, 5, 6, 7, 8};
auto l1 = mm->add_literal(migraphx::literal({migraphx::shape::float_type, {2, 4}}, w));
auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 2}});
mm->add_instruction(migraphx::make_op("dot"), l0, l1);
mm->add_instruction(migraphx::make_op("dot", {{"alpha", 1.0f}, {"beta", 0.0f}}), l0, l1);
auto prog = optimize_onnx("initializer_not_an_input.onnx");
......