Unverified commit b45f7239 authored by Shucai Xiao, committed by GitHub

QDQ for quantization and include subgraph (#891)

Add operators, refactor parsers, add rewrite passes, add tests
Add ref implementations
Move broadcasting of scales and zero points to the ONNX parser
Allow x and zero_point to have different types in quantizelinear; fix the default zero_point type
Extend fp16 and int8 quantization to include subgraphs and parameters
Fix the unit test to use QDQ operators for int8 quantization
Co-authored-by: turneram <alturner@amd.com>
parent fdaa21ee
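
For orientation before the diffs: the QDQ approach brackets each value to be quantized with a quantizelinear/dequantizelinear pair, so later passes can fold the pair into int8 kernels while the graph stays numerically close to the original. Below is a minimal, standalone sketch of the arithmetic involved, following the ONNX QuantizeLinear/DequantizeLinear definitions; the value and scale are made up and are not taken from this change:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    // made-up calibration results; real values come from the capture/calibration passes
    float x = 2.7f, scale = 0.05f;
    int8_t zero_point = 0;
    // quantizelinear: q = saturate(round(x / scale) + zero_point)
    int q = static_cast<int>(std::lround(x / scale)) + zero_point;
    q     = std::min(127, std::max(-128, q));
    // dequantizelinear: x' = (q - zero_point) * scale
    float x_reconstructed = (q - zero_point) * scale;
    std::printf("q = %d, x' = %g\n", q, x_reconstructed); // q = 54, x' = 2.7
}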
@@ -47,6 +47,8 @@ add_library(migraphx
     program.cpp
     propagate_constant.cpp
     quantization.cpp
+    quantize_fp16.cpp
+    quantize_int8.cpp
     reduce_dims.cpp
     register_op.cpp
     register_target.cpp
...
@@ -8,6 +8,7 @@
 #include <migraphx/literal.hpp>
 #include <migraphx/shape_for_each.hpp>
 #include <migraphx/config.hpp>
+#include <migraphx/context.hpp>
 #include <cmath>
 #include <utility>
@@ -29,7 +30,9 @@ struct capture
     shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }

-    argument compute(const shape&, std::vector<argument> args) const
+    // the context argument is added to prevent the op from being eliminated by
+    // constant propagation
+    argument compute(context&, const shape&, const std::vector<argument>& args) const
     {
         if(f)
         {
@@ -42,6 +45,8 @@ struct capture
         return args.front();
     }
+
+    std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
 };
 } // namespace op
...
@@ -26,7 +26,7 @@ struct dequantizelinear
     shape compute_shape(std::vector<shape> inputs) const
     {
         check_shapes{inputs, *this}.same_dims();
-        return {shape::float_type, inputs[0].lens(), inputs[0].strides()};
+        return {inputs[1].type(), inputs[0].lens(), inputs[0].strides()};
     }
     argument compute(const shape& output_shape, std::vector<argument> args) const
...
@@ -17,32 +17,10 @@ struct program;
 void quantize_fp16(program& prog, const std::vector<std::string>& ins_names = {"all"});

-// insert the capture operator for the inputs of each operator to be quantized
-// to int8
-std::size_t capture_arguments(program& prog,
-                              const std::vector<std::string>& ins_names,
-                              const std::function<void(std::size_t, std::vector<argument>)>& func);
-
-std::shared_ptr<std::vector<std::pair<float, float>>>
-capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names);
-
-template <class T>
-std::shared_ptr<std::vector<std::pair<float, float>>>
-capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_names)
-{
-    static_assert(std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, target>{} &&
-                      std::is_lvalue_reference<T>{},
-                  "Dangling reference to target!");
-    return capture_arguments_impl(prog, t, ins_names);
-}
-
 void quantize_int8(program& prog,
                    const target& t,
                    const std::vector<parameter_map>& calibration,
                    const std::vector<std::string>& ins_names = {"dot", "convolution"});

-void quantize_int8_impl(program& prog,
-                        const std::vector<std::pair<float, float>>& quant_params,
-                        const std::vector<std::string>& ins_names);
-
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_FP16_HPP

#include <string>
#include <vector>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;
struct module;

/**
 * quantize a program to fp16
 */
struct quantize_fp16_pass
{
    std::vector<std::string> ins_names = {"all"};
    std::string name() const { return "quantize_fp16"; }
    void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
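
The header above only declares the fp16 pass; as a rough usage sketch, assuming the module overload of migraphx::run_passes from pass_manager.hpp (which the new quantize_int8.cpp below also includes), it could be applied like this:

#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/quantize_fp16.hpp>

// Sketch only: run the new pass over a module; the default ins_names of {"all"}
// quantizes every supported instruction.
void quantize_module_to_fp16(migraphx::module& m)
{
    migraphx::run_passes(m, {migraphx::quantize_fp16_pass{}});
}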
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_INT8_HPP

#include <string>
#include <vector>
#include <functional>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;
struct module;

/**
 * capture inputs of operators to be quantized to int8
 */
struct capture_arguments_pass
{
    std::vector<std::string> ins_names = {"dot", "convolution"};
    std::function<void(std::size_t, std::vector<argument>)> f{};
    std::size_t* param_index = nullptr;

    std::string name() const { return "capture_arguments"; }
    void apply(module& m) const;
};

/**
 * quantize a program to int8
 */
struct quantize_int8_pass
{
    std::vector<std::string> ins_names = {"dot", "convolution"};
    std::vector<std::pair<float, float>> quant_params;

    std::string name() const { return "quantize_int8"; }
    void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
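
These two passes are meant to run in sequence: capture_arguments_pass instruments the inputs of the selected instructions, a calibration run fills quant_params, and quantize_int8_pass then rewrites each captured input into a quantizelinear/dequantizelinear pair. Below is a sketch of the wiring, assuming the module overload of migraphx::run_passes and precomputed calibration results; the real driver is the quantize_int8 entry point declared in quantization.hpp:

#include <cstddef>
#include <utility>
#include <vector>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/quantize_int8.hpp>

// Sketch only: in the real flow the program is evaluated on calibration data
// between the two passes so that quant_params can be derived from the captured
// arguments; here they are assumed to be known already.
void quantize_module_to_int8(migraphx::module& m,
                             std::vector<std::pair<float, float>> quant_params)
{
    std::size_t index = 0;
    migraphx::capture_arguments_pass capture;
    capture.param_index = &index; // gives each inserted capture op a unique ins_index
    migraphx::run_passes(m, {capture});

    migraphx::quantize_int8_pass q8;
    q8.quant_params = std::move(quant_params);
    migraphx::run_passes(m, {q8});
}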
@@ -763,6 +763,22 @@ void program::remove_module(const std::string& name)
            impl->modules.at(name).end(),
            [&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) &&
        "Instruction referenced in another module");
+
+    // if an instruction has an input outside of the current module, remove
+    // the instruction from that input's outputs before erasing the module
+    auto& mod = impl->modules.at(name);
+    for(auto ins : iterator_for(mod))
+    {
+        auto inputs = ins->inputs();
+        for(auto in : inputs)
+        {
+            if(not mod.has_instruction(in))
+            {
+                in->remove_output(ins);
+            }
+        }
+    }
+
     impl->modules.erase(name);
 }
...
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

static void quantize_module(module& m, const std::vector<std::string>& ins_names)
{
    for(auto ins : iterator_for(m))
    {
        // skip instructions that are not in the set to be quantized
        if(not(contains(ins_names, ins->name()) or contains(ins_names, "all")))
            continue;

        // skip return and convert instructions
        if(contains({"@return", "convert"}, ins->name()))
            continue;

        if(ins->inputs().empty())
            continue;

        auto mod_inputs = ins->module_inputs();
        auto s          = ins->get_shape();
        // Convert back to the original type before quantizing the inputs
        if(mod_inputs.empty())
        {
            auto r = m.insert_instruction(
                std::next(ins), make_op("convert", {{"target_type", s.type()}}), ins);
            m.replace_instruction(ins, r);
        }

        // Convert each of the inputs that are floating point to fp16
        auto inputs = ins->inputs();
        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
            auto input_type = input->get_shape().type();
            if(input_type != shape::float_type and input_type != shape::double_type)
                return input;
            return m.insert_instruction(
                ins, make_op("convert", {{"target_type", shape::half_type}}), input);
        });

        // Replace inputs
        m.replace_instruction(ins, ins->get_operator(), inputs, mod_inputs);
    }
}

void quantize_fp16_pass::apply(module& m) const { quantize_module(m, ins_names); }

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
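
To see what quantize_module produces, here is a rough end-to-end sketch; the shapes and names are made up, and it assumes the module overload of run_passes plus the stream operator for program:

#include <iostream>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/quantize_fp16.hpp>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {4, 4}};
    auto a = mm->add_parameter("a", s);
    auto b = mm->add_parameter("b", s);
    mm->add_instruction(migraphx::make_op("dot"), a, b);

    migraphx::run_passes(*mm, {migraphx::quantize_fp16_pass{}});
    // expect convert(half_type) on both inputs and convert(float_type) on the output
    std::cout << p << std::endl;
}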
#include <migraphx/operation.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/capture.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <numeric>
#include <set>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)

static std::vector<shape::type_t>& get_quantizable_type()
{
    static std::vector<shape::type_t> quantable_types = {
        shape::float_type, shape::double_type, shape::half_type};
    return quantable_types;
}

// Replace each capture instruction with a quantizelinear/dequantizelinear pair
// built from the calibrated (scale, zero_point) entry recorded for its ins_index.
void quantize_int8_pass::apply(module& m) const // NOLINT
{
    const auto& quantizable_types = get_quantizable_type();

    for(auto ins : iterator_for(m))
    {
        if(ins->name() != "capture")
            continue;

        auto op_val = ins->get_operator().to_value();
        assert(op_val.contains("ins_index"));

        auto param_index = op_val.at("ins_index").to<std::size_t>();
        auto param       = quant_params[param_index];

        auto input = ins->inputs().front();
        auto s     = input->get_shape();
        if(contains(quantizable_types, s.type()) and s.type() != shape::int8_type)
        {
            auto zero_point = m.add_literal(static_cast<int8_t>(param.second));
            auto scale      = m.add_literal(literal({s.type()}, {1.0f / param.first}));

            const auto& lens = s.lens();
            scale =
                m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), scale);
            zero_point = m.insert_instruction(
                ins, make_op("multibroadcast", {{"out_lens", lens}}), zero_point);
            auto q_in =
                m.insert_instruction(ins, make_op("quantizelinear"), input, scale, zero_point);
            auto dq_in =
                m.insert_instruction(ins, make_op("dequantizelinear"), q_in, scale, zero_point);
            m.replace_instruction(ins, dq_in);
        }
    }
}

// Wrap every input of the selected instructions in a capture op so the
// calibration callback f can observe the runtime arguments.
void capture_arguments_pass::apply(module& m) const // NOLINT
{
    assert(param_index != nullptr);

    for(auto ins : iterator_for(m))
    {
        if(not contains(ins_names, ins->name()))
        {
            continue;
        }

        auto inputs = ins->inputs();
        std::vector<instruction_ref> new_args;
        for(auto input : inputs)
        {
            auto new_in = m.insert_instruction(ins, op::capture{(*param_index)++, f}, input);
            new_args.push_back(new_in);
        }
        m.replace_instruction(ins, ins->get_operator(), new_args);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
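
The capture op (see op/capture.hpp above) forwards its input unchanged and, when the callback f is set, invokes it with the op's ins_index and the runtime arguments; that is the hook through which calibration data can be gathered for quant_params. Below is a sketch of what such a callback might look like; the recording logic is hypothetical and not part of this change:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <map>
#include <vector>
#include <migraphx/argument.hpp>

// Sketch only: calibration state and a callback matching the
// std::function<void(std::size_t, std::vector<argument>)> member of
// capture_arguments_pass; it tracks the largest magnitude seen per capture index.
std::map<std::size_t, float> max_abs;

std::function<void(std::size_t, std::vector<migraphx::argument>)> make_recorder()
{
    return [](std::size_t index, std::vector<migraphx::argument> args) {
        args.front().visit([&](auto data) {
            for(auto v : data)
                max_abs[index] = std::max(max_abs[index], std::fabs(static_cast<float>(v)));
        });
    };
}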
@@ -92,7 +92,8 @@ struct match_find_quantizable_ops
             dq = m.insert_instruction(
                 qop, migraphx::make_op("quant_dot", {{"alpha", 1}, {"beta", 0}}), qop_args);
         }
-        dq_scale = m.add_literal(static_cast<float>(scale));
+        auto ins_type = qop->get_shape().type();
+        dq_scale      = m.add_literal(literal({ins_type}, {scale}));
         auto lens = dq->get_shape().lens();
         auto scale_mb =
...
@@ -68,8 +68,7 @@ TEST_CASE(int8_quantization)
         migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
         auto pa = mm->add_parameter("a", sa);
         auto pb = mm->add_parameter("b", sb);
-        auto pc = mm->add_parameter("c", sc);
-        mm->add_instruction(migraphx::op::dot{}, pa, pb, pc);
+        mm->add_instruction(migraphx::op::dot{}, pa, pb);
         return p;
     };
@@ -82,7 +81,6 @@ TEST_CASE(int8_quantization)
     migraphx::shape sc{migraphx::shape::float_type, {5, 8}};
     m["a"] = migraphx::generate_argument(sa);
     m["b"] = migraphx::generate_argument(sb);
-    m["c"] = migraphx::generate_argument(sc);
     std::vector<float> ref_result;
     migraphx::target ref_t = migraphx::ref::target{};
     run_prog(p, ref_t, m, ref_result);
...
@@ -2741,43 +2741,6 @@ TEST_CASE(not_test)
     }
 }

-TEST_CASE(op_capture)
-{
-    migraphx::program p;
-    auto* mm = p.get_main_module();
-    migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
-    migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
-    std::vector<float> d1(s1.elements());
-    std::vector<float> d2(s2.elements());
-    std::iota(d1.begin(), d1.end(), 0.0f);
-    std::iota(d2.begin(), d2.end(), 0.0f);
-
-    auto p1 = mm->add_literal(s1, d1);
-    auto p2 = mm->add_literal(s1, d1);
-    auto pb = mm->add_literal(s2, d2);
-    auto pc = mm->add_literal(s2, d2);
-    auto pa = mm->add_instruction(migraphx::make_op("add"), p1, p2);
-    auto ps = mm->add_instruction(migraphx::make_op("dot"), pa, pb, pc);
-    mm->add_instruction(migraphx::make_op("dot"), pa, ps);
-
-    migraphx::program capture_p = p;
-    migraphx::target t = migraphx::ref::target{};
-    migraphx::capture_arguments(capture_p, t, {"dot"});
-
-    p.compile(migraphx::ref::target{});
-    capture_p.compile(migraphx::ref::target{});
-
-    auto cap_res = capture_p.eval({}).back();
-    auto res = p.eval({}).back();
-
-    std::vector<float> vec;
-    std::vector<float> cap_vec;
-    cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
-    res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
-
-    EXPECT(migraphx::verify_range(vec, cap_vec));
-}
-
 TEST_CASE(pad_test)
 {
     migraphx::program p;
...