Unverified Commit 7e61114a authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

refactoring quantization passes (#2544)

parent b742b528
...@@ -231,13 +231,13 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names) ...@@ -231,13 +231,13 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
struct quantize_int8_options struct quantize_int8_options
{ {
std::vector<parameter_map> calibration = {}; std::vector<parameter_map> calibration = {};
std::vector<std::string> op_names = {}; std::unordered_set<std::string> op_names = {};
}; };
void add_op_name(quantize_int8_options& options, const char* name) void add_op_name(quantize_int8_options& options, const char* name)
{ {
options.op_names.push_back(name); options.op_names.insert(name);
} }
void add_calibration_data(quantize_int8_options& options, parameter_map& data) void add_calibration_data(quantize_int8_options& options, parameter_map& data)
......
...@@ -44,8 +44,8 @@ MIGRAPHX_EXPORT void quantize_fp16(program& prog, ...@@ -44,8 +44,8 @@ MIGRAPHX_EXPORT void quantize_fp16(program& prog,
MIGRAPHX_EXPORT void quantize_int8(program& prog, MIGRAPHX_EXPORT void quantize_int8(program& prog,
const target& t, const target& t,
const std::vector<parameter_map>& calibration, const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", const std::unordered_set<std::string>& ins_names = {
"convolution"}); "dot", "convolution"});
MIGRAPHX_EXPORT void MIGRAPHX_EXPORT void
quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration); quantize_fp8(program& prog, const target& t, const std::vector<parameter_map>& calibration);
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP #define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#include <string> #include <string>
#include <unordered_set>
#include <vector> #include <vector>
#include <functional> #include <functional>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
...@@ -41,7 +42,7 @@ struct module; ...@@ -41,7 +42,7 @@ struct module;
*/ */
struct MIGRAPHX_EXPORT capture_arguments_pass struct MIGRAPHX_EXPORT capture_arguments_pass
{ {
std::vector<std::string> ins_names = {"dot", "convolution"}; std::unordered_set<std::string> ins_names = {"dot", "convolution"};
std::function<void(std::size_t, std::vector<argument>)> f{}; std::function<void(std::size_t, std::vector<argument>)> f{};
std::size_t* param_index = nullptr; std::size_t* param_index = nullptr;
std::string name() const { return "capture_arguments"; } std::string name() const { return "capture_arguments"; }
...@@ -53,8 +54,7 @@ struct MIGRAPHX_EXPORT capture_arguments_pass ...@@ -53,8 +54,7 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
*/ */
struct MIGRAPHX_EXPORT quantize_8bits_pass struct MIGRAPHX_EXPORT quantize_8bits_pass
{ {
shape::type_t precision = shape::int8_type; shape::type_t precision = shape::int8_type;
std::vector<std::string> ins_names = {"dot", "convolution"};
std::vector<std::pair<float, float>> quant_params; std::vector<std::pair<float, float>> quant_params;
std::string name() const { return "quantize_8bits"; } std::string name() const { return "quantize_8bits"; }
void apply(module& m) const; void apply(module& m) const;
......
...@@ -580,7 +580,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -580,7 +580,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("prog"), py::arg("prog"),
py::arg("t"), py::arg("t"),
py::arg("calibration") = std::vector<migraphx::parameter_map>{}, py::arg("calibration") = std::vector<migraphx::parameter_map>{},
py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"}); py::arg("ins_names") = std::unordered_set<std::string>{"dot", "convolution"});
#ifdef HAVE_GPU #ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
...@@ -61,7 +61,7 @@ void quantize_8bits(program& prog, ...@@ -61,7 +61,7 @@ void quantize_8bits(program& prog,
const target& t, const target& t,
shape::type_t precision, shape::type_t precision,
const std::vector<parameter_map>& calibration, const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names) const std::unordered_set<std::string>& ins_names)
{ {
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to // Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
// avoid loss of precision. // avoid loss of precision.
...@@ -138,7 +138,7 @@ void quantize_8bits(program& prog, ...@@ -138,7 +138,7 @@ void quantize_8bits(program& prog,
} }
run_passes(prog, run_passes(prog,
{quantize_8bits_pass{precision, ins_names, *quant_8bit_params}, {quantize_8bits_pass{precision, *quant_8bit_params},
simplify_qdq{}, simplify_qdq{},
optimize_module{}, optimize_module{},
dead_code_elimination{}}); dead_code_elimination{}});
...@@ -147,12 +147,10 @@ void quantize_8bits(program& prog, ...@@ -147,12 +147,10 @@ void quantize_8bits(program& prog,
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
const std::vector<parameter_map>& calibration, const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names) const std::unordered_set<std::string>& ins_names)
{ {
std::set<std::string> op_names = {"convolution", "dot"}; std::unordered_set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end()); if(op_names != ins_names)
if(not std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{ {
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation"); MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
} }
...@@ -164,7 +162,7 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma ...@@ -164,7 +162,7 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in " std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs\n"; "incorrect final outputs\n";
std::vector<std::string> supported_ins_names; std::unordered_set<std::string> supported_ins_names;
auto* mm = prog.get_main_module(); auto* mm = prog.get_main_module();
for(auto ins : iterator_for(*mm)) for(auto ins : iterator_for(*mm))
{ {
...@@ -172,9 +170,9 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma ...@@ -172,9 +170,9 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
{ {
continue; continue;
} }
else if(not starts_with(ins->name(), "@")) if(not starts_with(ins->name(), "@"))
{ {
supported_ins_names.push_back(ins->name()); supported_ins_names.insert(ins->name());
} }
} }
quantize_8bits(prog, t, shape::fp8e4m3fnuz_type, calibration, supported_ins_names); quantize_8bits(prog, t, shape::fp8e4m3fnuz_type, calibration, supported_ins_names);
......
...@@ -90,11 +90,7 @@ void capture_arguments_pass::apply(module& m) const // NOLINT ...@@ -90,11 +90,7 @@ void capture_arguments_pass::apply(module& m) const // NOLINT
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(not contains(ins_names, ins->name())) if((not contains(ins_names, ins->name())) or (ins->name() == "convert"))
{
continue;
}
if(ins->name() == "convert")
{ {
continue; continue;
} }
......
...@@ -654,7 +654,7 @@ TEST_CASE(dot_float) ...@@ -654,7 +654,7 @@ TEST_CASE(dot_float)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes( migraphx::run_passes(
p, p,
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, {"dot"}, quant_params}, {migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, quant_params},
migraphx::dead_code_elimination{}}); migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -749,7 +749,7 @@ TEST_CASE(dot_double_2args) ...@@ -749,7 +749,7 @@ TEST_CASE(dot_double_2args)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes( migraphx::run_passes(
p, p,
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, {"dot"}, quant_params}, {migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, quant_params},
migraphx::dead_code_elimination{}}); migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog()); EXPECT(p == create_int8_quantized_prog());
...@@ -823,7 +823,7 @@ TEST_CASE(dot_half_1arg) ...@@ -823,7 +823,7 @@ TEST_CASE(dot_half_1arg)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes( migraphx::run_passes(
p, p,
{migraphx::quantize_8bits_pass{migraphx::shape::int8_type, {"dot"}, quant_params}, {migraphx::quantize_8bits_pass{migraphx::shape::int8_type, quant_params},
migraphx::dead_code_elimination{}}); migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog()); EXPECT(p == create_int8_quantized_prog());
...@@ -881,7 +881,7 @@ TEST_CASE(conv_float) ...@@ -881,7 +881,7 @@ TEST_CASE(conv_float)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{ {migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"convolution"}, quant_params}}); migraphx::shape::type_t::int8_type, quant_params}});
optimize_prog_int8(p); optimize_prog_int8(p);
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -908,7 +908,7 @@ TEST_CASE(conv_float_throw) ...@@ -908,7 +908,7 @@ TEST_CASE(conv_float_throw)
test::throws([&] { test::throws([&] {
migraphx::run_passes(p, migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{ {migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"add"}, quant_params}}); migraphx::shape::type_t::int8_type, quant_params}});
}); });
} }
...@@ -961,7 +961,7 @@ TEST_CASE(conv_half) ...@@ -961,7 +961,7 @@ TEST_CASE(conv_half)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{ {migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"convolution"}, quant_params}}); migraphx::shape::type_t::int8_type, quant_params}});
optimize_prog_int8(p); optimize_prog_int8(p);
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -1242,7 +1242,6 @@ TEST_CASE(int8_subgraph) ...@@ -1242,7 +1242,6 @@ TEST_CASE(int8_subgraph)
p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, &param_index}}); p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, &param_index}});
migraphx::run_passes(p1, migraphx::run_passes(p1,
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, {migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type,
{"convolution", "dot"},
quant_params}}); quant_params}});
optimize_prog_int8(p1); optimize_prog_int8(p1);
......
...@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names) ...@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
struct quantize_int8_options struct quantize_int8_options
{ {
std::vector<parameter_map> calibration = {}; std::vector<parameter_map> calibration = {};
std::vector<std::string> op_names = {}; std::unordered_set<std::string> op_names = {};
}; };
void add_op_name(quantize_int8_options& options, const char* name) void add_op_name(quantize_int8_options& options, const char* name)
{ {
options.op_names.push_back(name); options.op_names.insert(name);
} }
void add_calibration_data(quantize_int8_options& options, parameter_map& data) void add_calibration_data(quantize_int8_options& options, parameter_map& data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment