Commit 684d5a4e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix review comments

parent 0fbdaf58
...@@ -15,10 +15,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -15,10 +15,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS) void quantize_fp16(program& prog, const std::vector<std::string>& ins_names = {"all"});
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names);
void quantize_fp16(program& prog);
// insert the capture operator for the inputs of each operator to be quantized // insert the capture operator for the inputs of each operator to be quantized
// to int8 // to int8
...@@ -31,7 +28,7 @@ capture_arguments_impl(program& prog, const target& t, const std::vector<std::st ...@@ -31,7 +28,7 @@ capture_arguments_impl(program& prog, const target& t, const std::vector<std::st
template <class T> template <class T>
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments( std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(
program& prog, T&& t, const std::vector<std::string>& ins_names = {"dot", "convolution"}) program& prog, T&& t, const std::vector<std::string>& ins_names)
{ {
static_assert(std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, target>{} && static_assert(std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, target>{} &&
std::is_lvalue_reference<T>{}, std::is_lvalue_reference<T>{},
...@@ -41,9 +38,9 @@ std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments( ...@@ -41,9 +38,9 @@ std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
std::vector<program::parameter_map>& calibration_args, std::vector<program::parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", "convolution"}); const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8(program& prog, void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params, const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names); const std::vector<std::string>& ins_names);
......
...@@ -183,29 +183,9 @@ PYBIND11_MODULE(migraphx, m) ...@@ -183,29 +183,9 @@ PYBIND11_MODULE(migraphx, m)
}); });
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0); m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("quantize_fp16", [](migraphx::program& p, std::vector<std::string>& ins_names) { m.def("quantize_fp16", &migraphx::quantize_fp16, py::arg("prog"), py::arg("ins_names") = std::vector<std::string>{"all"});
migraphx::quantize_fp16(p, ins_names); m.def("quantize_int8", &migraphx::quantize_int8, py::arg("prog"), py::arg("t"), py::arg("calibration") = std::vector<migraphx::program::parameter_map>{},
}); py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
m.def("quantize_fp16", [](migraphx::program& p) { migraphx::quantize_fp16(p, {"all"}); });
m.def("quantize_int8",
[](migraphx::program& p,
std::vector<std::string>& ins_names,
std::vector<std::pair<float, float>>& quant_params) {
migraphx::quantize_int8(p, quant_params, ins_names);
});
m.def("quantize_int8",
[](migraphx::program& p,
const migraphx::target& t,
std::vector<std::string>& ins_names,
std::vector<migraphx::program::parameter_map>& cali_args) {
migraphx::quantize_int8(p, t, cali_args, ins_names);
});
m.def("quantize_int8",
[](migraphx::program& p,
const migraphx::target& t,
std::vector<migraphx::program::parameter_map>& cali_args) {
migraphx::quantize_int8(p, t, cali_args);
});
#ifdef HAVE_GPU #ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
...@@ -17,12 +17,16 @@ ...@@ -17,12 +17,16 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/target.hpp> #include <migraphx/target.hpp>
#include <utility> #include <utility>
#include <set>
#include <iomanip> #include <iomanip>
#include <fstream> #include <fstream>
#include <algorithm>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
instruction_ref insert_quant_ins(program& prog, instruction_ref insert_quant_ins(program& prog,
instruction_ref& ins, instruction_ref& ins,
shape::type_t type, shape::type_t type,
...@@ -161,8 +165,6 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names) ...@@ -161,8 +165,6 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
} }
} }
void quantize_fp16(program& prog) { quantize_fp16(prog, {"all"}); }
static void ins_quantize_int8(program& prog, static void ins_quantize_int8(program& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref>& converted_inputs, std::vector<instruction_ref>& converted_inputs,
...@@ -305,7 +307,7 @@ static void ins_quantize_int8(program& prog, ...@@ -305,7 +307,7 @@ static void ins_quantize_int8(program& prog,
// -128 ~ 127. To convert the float or double to int8, we need a scale and // -128 ~ 127. To convert the float or double to int8, we need a scale and
// a shift, then the convert can be done as v_int8 = fp * scale + shift. // a shift, then the convert can be done as v_int8 = fp * scale + shift.
// To simplify the changes, we consider shift as 0.0f for now. // To simplify the changes, we consider shift as 0.0f for now.
void quantize_int8(program& prog, void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params, const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names) const std::vector<std::string>& ins_names)
{ {
...@@ -321,10 +323,9 @@ void quantize_int8(program& prog, ...@@ -321,10 +323,9 @@ void quantize_int8(program& prog,
} }
// For now, we only support the int8 quantization of gemm and convolution // For now, we only support the int8 quantization of gemm and convolution
std::vector<std::string> op_names = {"dot", "convolution"}; std::set<std::string> op_names = {"convolution", "dot"};
if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) { std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
return (std::find(op_names.begin(), op_names.end(), name) != op_names.end()); if (!std::includes(op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
}))
{ {
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation"); MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
} }
...@@ -381,7 +382,7 @@ void quantize_int8(program& prog, ...@@ -381,7 +382,7 @@ void quantize_int8(program& prog,
quant_input = input->inputs().front(); quant_input = input->inputs().front();
// the scale in this case is not used, so tune the scale // the scale in this case is not used, so tune the scale
// to 1.0f for this parameter // to 1.0f for this parameter
ins_quant_params.back() = std::make_pair<float, float>(1.0f, 0.0f); ins_quant_params.back() = std::pair<float, float>(1.0f, 0.0f);
} }
else else
{ {
...@@ -413,7 +414,7 @@ void quantize_int8(program& prog, ...@@ -413,7 +414,7 @@ void quantize_int8(program& prog,
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
std::vector<program::parameter_map>& calibration_args, std::vector<program::parameter_map>& calibration,
const std::vector<std::string>& ins_names) const std::vector<std::string>& ins_names)
{ {
// insert capture operator // insert capture operator
...@@ -425,7 +426,7 @@ void quantize_int8(program& prog, ...@@ -425,7 +426,7 @@ void quantize_int8(program& prog,
// use all calibration data to run the program to calculate the // use all calibration data to run the program to calculate the
// quantization scale and shift // quantization scale and shift
for(auto&& arg : calibration_args) for(auto&& arg : calibration)
{ {
program::parameter_map m; program::parameter_map m;
for(auto&& x : cap_prog.get_parameter_shapes()) for(auto&& x : cap_prog.get_parameter_shapes())
...@@ -443,7 +444,7 @@ void quantize_int8(program& prog, ...@@ -443,7 +444,7 @@ void quantize_int8(program& prog,
cap_prog.eval(m); cap_prog.eval(m);
} }
quantize_int8(prog, *int8_quant_params, ins_names); quantize_int8_impl(prog, *int8_quant_params, ins_names);
} }
// For the input of each input argument, we need to insert a // For the input of each input argument, we need to insert a
...@@ -456,9 +457,8 @@ std::size_t capture_arguments(program& prog, ...@@ -456,9 +457,8 @@ std::size_t capture_arguments(program& prog,
size_t num_quant_params = 0; size_t num_quant_params = 0;
// the int8 quantization only support dot and convolution // the int8 quantization only support dot and convolution
std::vector<std::string> op_names = {"dot", "convolution"}; std::vector<std::string> op_names = {"dot", "convolution"};
if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) { std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
return std::find(op_names.begin(), op_names.end(), name) != op_names.end(); if (!std::includes(op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
}))
{ {
MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported"); MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
} }
......
...@@ -2068,7 +2068,7 @@ TEST_CASE(op_capture) ...@@ -2068,7 +2068,7 @@ TEST_CASE(op_capture)
migraphx::program capture_p = p; migraphx::program capture_p = p;
migraphx::target t = migraphx::cpu::target{}; migraphx::target t = migraphx::cpu::target{};
migraphx::capture_arguments(capture_p, t); migraphx::capture_arguments(capture_p, t, {"dot"});
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
capture_p.compile(migraphx::cpu::target{}); capture_p.compile(migraphx::cpu::target{});
......
...@@ -250,7 +250,7 @@ TEST_CASE(op_capture) ...@@ -250,7 +250,7 @@ TEST_CASE(op_capture)
auto p = create_program_float(); auto p = create_program_float();
auto op_capture_p = create_program_op(); auto op_capture_p = create_program_op();
migraphx::target t = migraphx::cpu::target{}; migraphx::target t = migraphx::cpu::target{};
migraphx::capture_arguments(p, t); migraphx::capture_arguments(p, t, {"dot", "convolution"});
EXPECT(p == op_capture_p); EXPECT(p == op_capture_p);
} }
} }
...@@ -313,7 +313,7 @@ TEST_CASE(dot_float) ...@@ -313,7 +313,7 @@ TEST_CASE(dot_float)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{ const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; {0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"}); migraphx::quantize_int8_impl(p, quant_params, {"dot"});
migraphx::run_passes(p, {migraphx::dead_code_elimination{}}); migraphx::run_passes(p, {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -375,7 +375,7 @@ TEST_CASE(dot_double_2args) ...@@ -375,7 +375,7 @@ TEST_CASE(dot_double_2args)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"}); migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -440,7 +440,7 @@ TEST_CASE(dot_large_alpha_beta_float) ...@@ -440,7 +440,7 @@ TEST_CASE(dot_large_alpha_beta_float)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{ const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"}); migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -504,7 +504,7 @@ TEST_CASE(dot_large_alpha_beta_int32) ...@@ -504,7 +504,7 @@ TEST_CASE(dot_large_alpha_beta_int32)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{ const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"}); migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -548,7 +548,7 @@ TEST_CASE(dot_int32_one_arg) ...@@ -548,7 +548,7 @@ TEST_CASE(dot_int32_one_arg)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{1.0f, 1.0f}}; const std::vector<std::pair<float, float>>& quant_params{{1.0f, 1.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"}); migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -622,7 +622,7 @@ TEST_CASE(dot_int32) ...@@ -622,7 +622,7 @@ TEST_CASE(dot_int32)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{ const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"}); migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -671,7 +671,7 @@ TEST_CASE(dot_float_convert) ...@@ -671,7 +671,7 @@ TEST_CASE(dot_float_convert)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"}); migraphx::quantize_int8_impl(p, quant_params, {"dot"});
migraphx::run_passes(p, {migraphx::dead_code_elimination{}}); migraphx::run_passes(p, {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -726,7 +726,7 @@ TEST_CASE(conv_float) ...@@ -726,7 +726,7 @@ TEST_CASE(conv_float)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, quant_params, {"convolution"}); migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -782,7 +782,7 @@ TEST_CASE(conv_int32) ...@@ -782,7 +782,7 @@ TEST_CASE(conv_int32)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, quant_params, {"convolution"}); migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -840,7 +840,7 @@ TEST_CASE(conv_half) ...@@ -840,7 +840,7 @@ TEST_CASE(conv_half)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, quant_params, {"convolution"}); migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment