Unverified commit 7e61114a, authored by Umang Yadav and committed by GitHub
Browse files

refactoring quantization passes (#2544)

parent b742b528
......@@ -231,13 +231,13 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
/**
 * Options bundle for int8 quantization.
 *
 * Each member is declared exactly once; `op_names` is a set, so adding the
 * same operator name more than once leaves a single entry.
 */
struct quantize_int8_options
{
    // Calibration parameter maps. Presumably filled by add_calibration_data(),
    // whose body is not visible here -- TODO confirm.
    std::vector<parameter_map> calibration = {};
    // Operator names collected by add_op_name().
    std::unordered_set<std::string> op_names = {};
};
/**
 * Adds `name` to `options.op_names`.
 *
 * `op_names` is a std::unordered_set, so the name is inserted once via
 * insert(); a name that is already present is left as-is.
 *
 * @param options Options object to update.
 * @param name    NUL-terminated operator name; must not be null, because it is
 *                converted to std::string.
 */
void add_op_name(quantize_int8_options& options, const char* name)
{
    options.op_names.insert(name);
}
void add_calibration_data(quantize_int8_options& options, parameter_map& data)
......
......@@ -44,8 +44,8 @@ MIGRAPHX_EXPORT void quantize_fp16(program& prog,
/**
 * Declaration of int8 quantization for `prog` on target `t`.
 *
 * @param prog        Program to quantize (modified in place).
 * @param t           Target passed through to the quantization routine.
 * @param calibration Calibration parameter maps.
 * @param ins_names   Operator names to quantize; defaults to
 *                    {"dot", "convolution"}.
 *
 * NOTE(review): the definition shown elsewhere in this file throws unless
 * `ins_names` compares equal to {"convolution", "dot"}, so passing only one of
 * the two names is rejected -- confirm that this is intended.
 */
MIGRAPHX_EXPORT void quantize_int8(program& prog,
                                   const target& t,
                                   const std::vector<parameter_map>& calibration,
                                   const std::unordered_set<std::string>& ins_names = {
                                       "dot", "convolution"});
/**
 * Declaration of fp8 quantization for `prog` on target `t`.
 *
 * Takes the program, the target and the calibration parameter maps; unlike
 * quantize_int8 it has no operator-name parameter.
 */
MIGRAPHX_EXPORT void quantize_fp8(program& prog,
                                  const target& t,
                                  const std::vector<parameter_map>& calibration);
......
......@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZE_8BITS_HPP
#include <string>
#include <unordered_set>
#include <vector>
#include <functional>
#include <migraphx/argument.hpp>
......@@ -41,7 +42,7 @@ struct module;
*/
struct MIGRAPHX_EXPORT capture_arguments_pass
{
std::vector<std::string> ins_names = {"dot", "convolution"};
std::unordered_set<std::string> ins_names = {"dot", "convolution"};
std::function<void(std::size_t, std::vector<argument>)> f{};
std::size_t* param_index = nullptr;
std::string name() const { return "capture_arguments"; }
......@@ -53,8 +54,7 @@ struct MIGRAPHX_EXPORT capture_arguments_pass
*/
struct MIGRAPHX_EXPORT quantize_8bits_pass
{
shape::type_t precision = shape::int8_type;
std::vector<std::string> ins_names = {"dot", "convolution"};
shape::type_t precision = shape::int8_type;
std::vector<std::pair<float, float>> quant_params;
std::string name() const { return "quantize_8bits"; }
void apply(module& m) const;
......
......@@ -580,7 +580,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("prog"),
py::arg("t"),
py::arg("calibration") = std::vector<migraphx::parameter_map>{},
py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
py::arg("ins_names") = std::unordered_set<std::string>{"dot", "convolution"});
#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
......@@ -61,7 +61,7 @@ void quantize_8bits(program& prog,
const target& t,
shape::type_t precision,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names)
const std::unordered_set<std::string>& ins_names)
{
// Run optimize_module() before converting to int8/fp8 to const eval and fold in FP32 to
// avoid loss of precision.
......@@ -138,7 +138,7 @@ void quantize_8bits(program& prog,
}
run_passes(prog,
{quantize_8bits_pass{precision, ins_names, *quant_8bit_params},
{quantize_8bits_pass{precision, *quant_8bit_params},
simplify_qdq{},
optimize_module{},
dead_code_elimination{}});
......@@ -147,12 +147,10 @@ void quantize_8bits(program& prog,
void quantize_int8(program& prog,
const target& t,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names)
const std::unordered_set<std::string>& ins_names)
{
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if(not std::includes(
op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
std::unordered_set<std::string> op_names = {"convolution", "dot"};
if(op_names != ins_names)
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
......@@ -164,7 +162,7 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs\n";
std::vector<std::string> supported_ins_names;
std::unordered_set<std::string> supported_ins_names;
auto* mm = prog.get_main_module();
for(auto ins : iterator_for(*mm))
{
......@@ -172,9 +170,9 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
{
continue;
}
else if(not starts_with(ins->name(), "@"))
if(not starts_with(ins->name(), "@"))
{
supported_ins_names.push_back(ins->name());
supported_ins_names.insert(ins->name());
}
}
quantize_8bits(prog, t, shape::fp8e4m3fnuz_type, calibration, supported_ins_names);
......
......@@ -90,11 +90,7 @@ void capture_arguments_pass::apply(module& m) const // NOLINT
for(auto ins : iterator_for(m))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
if(ins->name() == "convert")
if((not contains(ins_names, ins->name())) or (ins->name() == "convert"))
{
continue;
}
......
......@@ -654,7 +654,7 @@ TEST_CASE(dot_float)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, {"dot"}, quant_params},
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, quant_params},
migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
......@@ -749,7 +749,7 @@ TEST_CASE(dot_double_2args)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, {"dot"}, quant_params},
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, quant_params},
migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog());
......@@ -823,7 +823,7 @@ TEST_CASE(dot_half_1arg)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes(
p,
{migraphx::quantize_8bits_pass{migraphx::shape::int8_type, {"dot"}, quant_params},
{migraphx::quantize_8bits_pass{migraphx::shape::int8_type, quant_params},
migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog());
......@@ -881,7 +881,7 @@ TEST_CASE(conv_float)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"convolution"}, quant_params}});
migraphx::shape::type_t::int8_type, quant_params}});
optimize_prog_int8(p);
auto qp = create_int8_quantized_prog();
......@@ -908,7 +908,7 @@ TEST_CASE(conv_float_throw)
test::throws([&] {
migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"add"}, quant_params}});
migraphx::shape::type_t::int8_type, quant_params}});
});
}
......@@ -961,7 +961,7 @@ TEST_CASE(conv_half)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, {"convolution"}, quant_params}});
migraphx::shape::type_t::int8_type, quant_params}});
optimize_prog_int8(p);
auto qp = create_int8_quantized_prog();
......@@ -1242,7 +1242,6 @@ TEST_CASE(int8_subgraph)
p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, &param_index}});
migraphx::run_passes(p1,
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type,
{"convolution", "dot"},
quant_params}});
optimize_prog_int8(p1);
......
......@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
/**
 * Options bundle for int8 quantization.
 *
 * `op_names` is declared once, as a set, so adding the same operator name
 * more than once leaves a single entry.
 */
struct quantize_int8_options
{
    // Calibration parameter maps. Presumably filled by add_calibration_data(),
    // whose body is not visible here -- TODO confirm.
    std::vector<parameter_map> calibration = {};
    // Operator names collected by add_op_name().
    std::unordered_set<std::string> op_names = {};
};
/**
 * Adds `name` to `options.op_names`.
 *
 * `op_names` is a std::unordered_set, so the name is inserted once via
 * insert(); a name that is already present is left as-is.
 *
 * @param options Options object to update.
 * @param name    NUL-terminated operator name; must not be null, because it is
 *                converted to std::string.
 */
void add_op_name(quantize_int8_options& options, const char* name)
{
    options.op_names.insert(name);
}
void add_calibration_data(quantize_int8_options& options, parameter_map& data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment