Commit 69da37bb authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix review comments

parent aefffcda
...@@ -14,8 +14,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -14,8 +14,8 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
void quantize(program& prog, const std::vector<std::string>& ins_names); void quantize_fp16(program& prog, const std::vector<std::string>& ins_names);
void quantize(program& prog); void quantize_fp16(program& prog);
// insert the capture operator for the inputs of each operator to be quantized // insert the capture operator for the inputs of each operator to be quantized
// to int8 // to int8
...@@ -24,9 +24,7 @@ std::size_t capture_arguments(program& prog, ...@@ -24,9 +24,7 @@ std::size_t capture_arguments(program& prog,
const std::function<void(std::size_t, std::vector<argument>)>& func); const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>> std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments_impl(program& prog, capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names);
const target& t,
const std::vector<std::string>& ins_names = {"dot", "convolution"});
template <class T> template <class T>
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments( std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(
...@@ -40,14 +38,11 @@ std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments( ...@@ -40,14 +38,11 @@ std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
std::vector<program::parameter_map>& calibration_args); std::vector<program::parameter_map>& calibration_args,
const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names);
std::vector<program::parameter_map>& calibration_args);
void quantize_int8(program& prog,
const std::vector<std::string>& ins_names,
const std::vector<std::pair<float, float>>& quant_params);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -183,22 +183,22 @@ PYBIND11_MODULE(migraphx, m) ...@@ -183,22 +183,22 @@ PYBIND11_MODULE(migraphx, m)
}); });
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0); m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("quantize", [](migraphx::program& p, std::vector<std::string>& ins_names) { m.def("quantize_fp16", [](migraphx::program& p, std::vector<std::string>& ins_names) {
migraphx::quantize(p, ins_names); migraphx::quantize_fp16(p, ins_names);
}); });
m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); }); m.def("quantize_fp16", [](migraphx::program& p) { migraphx::quantize_fp16(p, {"all"}); });
m.def("quantize_int8", m.def("quantize_int8",
[](migraphx::program& p, [](migraphx::program& p,
std::vector<std::string>& ins_names, std::vector<std::string>& ins_names,
std::vector<std::pair<float, float>>& quant_params) { std::vector<std::pair<float, float>>& quant_params) {
migraphx::quantize_int8(p, ins_names, quant_params); migraphx::quantize_int8(p, quant_params, ins_names);
}); });
m.def("quantize_int8", m.def("quantize_int8",
[](migraphx::program& p, [](migraphx::program& p,
const migraphx::target& t, const migraphx::target& t,
std::vector<std::string>& ins_names, std::vector<std::string>& ins_names,
std::vector<migraphx::program::parameter_map>& cali_args) { std::vector<migraphx::program::parameter_map>& cali_args) {
migraphx::quantize_int8(p, t, ins_names, cali_args); migraphx::quantize_int8(p, t, cali_args, ins_names);
}); });
m.def("quantize_int8", m.def("quantize_int8",
[](migraphx::program& p, [](migraphx::program& p,
......
...@@ -96,7 +96,7 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -96,7 +96,7 @@ instruction_ref insert_quant_ins(program& prog,
// For the conversion, there could be cases of overflowing, but it // For the conversion, there could be cases of overflowing, but it
// is very rare in the area of deeping learning, so we just do a // is very rare in the area of deeping learning, so we just do a
// truncate of the input to get the fp16. // truncate of the input to get the fp16.
void quantize(program& prog, const std::vector<std::string>& ins_names) void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{ {
std::unordered_map<instruction_ref, instruction_ref> map_fp16; std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(prog))
...@@ -161,7 +161,7 @@ void quantize(program& prog, const std::vector<std::string>& ins_names) ...@@ -161,7 +161,7 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
} }
} }
void quantize(program& prog) { quantize(prog, {"all"}); } void quantize_fp16(program& prog) { quantize_fp16(prog, {"all"}); }
static void ins_quantize_int8(program& prog, static void ins_quantize_int8(program& prog,
instruction_ref ins, instruction_ref ins,
...@@ -306,8 +306,8 @@ static void ins_quantize_int8(program& prog, ...@@ -306,8 +306,8 @@ static void ins_quantize_int8(program& prog,
// a shift, then the convert can be done as v_int8 = fp * scale + shift. // a shift, then the convert can be done as v_int8 = fp * scale + shift.
// To simplify the changes, we consider shift as 0.0f for now. // To simplify the changes, we consider shift as 0.0f for now.
void quantize_int8(program& prog, void quantize_int8(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::pair<float, float>>& quant_params) const std::vector<std::string>& ins_names)
{ {
// For now, we only support the int8 quantization of gemm and convolution // For now, we only support the int8 quantization of gemm and convolution
std::vector<std::string> op_names = {"dot", "convolution"}; std::vector<std::string> op_names = {"dot", "convolution"};
...@@ -402,8 +402,8 @@ void quantize_int8(program& prog, ...@@ -402,8 +402,8 @@ void quantize_int8(program& prog,
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
const std::vector<std::string>& ins_names, std::vector<program::parameter_map>& calibration_args,
std::vector<program::parameter_map>& calibration_args) const std::vector<std::string>& ins_names)
{ {
// insert capture operator // insert capture operator
auto cap_prog = prog; auto cap_prog = prog;
...@@ -432,15 +432,7 @@ void quantize_int8(program& prog, ...@@ -432,15 +432,7 @@ void quantize_int8(program& prog,
cap_prog.eval(m); cap_prog.eval(m);
} }
quantize_int8(prog, ins_names, *int8_quant_params); quantize_int8(prog, *int8_quant_params, ins_names);
}
void quantize_int8(program& prog,
const target& t,
std::vector<program::parameter_map>& calibration_args)
{
std::vector<std::string> ins_names = {"dot", "convolution"};
quantize_int8(prog, t, ins_names, calibration_args);
} }
// For the input of each input argument, we need to insert a // For the input of each input argument, we need to insert a
......
...@@ -1821,7 +1821,7 @@ TEST_CASE(fp32_fp16_test) ...@@ -1821,7 +1821,7 @@ TEST_CASE(fp32_fp16_test)
auto test_case = [&](std::vector<std::string>&& op_names) { auto test_case = [&](std::vector<std::string>&& op_names) {
std::vector<float> gold_res = {2.0, 4.0, 6.0, 8.0, 10.0, 12.0}; std::vector<float> gold_res = {2.0, 4.0, 6.0, 8.0, 10.0, 12.0};
auto p = create_program(); auto p = create_program();
migraphx::quantize(p, op_names); migraphx::quantize_fp16(p, op_names);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<float> res; std::vector<float> res;
......
...@@ -3694,7 +3694,7 @@ struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall> ...@@ -3694,7 +3694,7 @@ struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
auto l1 = p.add_literal(migraphx::literal(s, data)); auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_parameter("p2", s); auto l2 = p.add_parameter("p2", s);
p.add_instruction(migraphx::op::add{}, l1, l2); p.add_instruction(migraphx::op::add{}, l1, l2);
migraphx::quantize(p, {"all"}); migraphx::quantize_fp16(p, {"all"});
return p; return p;
}; };
}; };
...@@ -3710,7 +3710,7 @@ struct test_fp32_fp16_ladd : verify_program<test_fp32_fp16_ladd> ...@@ -3710,7 +3710,7 @@ struct test_fp32_fp16_ladd : verify_program<test_fp32_fp16_ladd>
auto l1 = p.add_literal(migraphx::literal(s, data)); auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_parameter("p2", s); auto l2 = p.add_parameter("p2", s);
p.add_instruction(migraphx::op::add{}, l1, l2); p.add_instruction(migraphx::op::add{}, l1, l2);
migraphx::quantize(p, {"add"}); migraphx::quantize_fp16(p, {"add"});
return p; return p;
}; };
}; };
...@@ -3726,7 +3726,7 @@ struct test_fp32_fp16_add : verify_program<test_fp32_fp16_add> ...@@ -3726,7 +3726,7 @@ struct test_fp32_fp16_add : verify_program<test_fp32_fp16_add>
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2); auto sum = p.add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2); auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1); p.add_instruction(migraphx::op::add{}, diff, p1);
migraphx::quantize(p, {"add"}); migraphx::quantize_fp16(p, {"add"});
return p; return p;
}; };
...@@ -3743,7 +3743,7 @@ struct test_fp32_fp16_sub : verify_program<test_fp32_fp16_sub> ...@@ -3743,7 +3743,7 @@ struct test_fp32_fp16_sub : verify_program<test_fp32_fp16_sub>
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2); auto sum = p.add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2); auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1); p.add_instruction(migraphx::op::add{}, diff, p1);
migraphx::quantize(p, {"sub"}); migraphx::quantize_fp16(p, {"sub"});
return p; return p;
}; };
......
...@@ -43,7 +43,7 @@ TEST_CASE(param_add) ...@@ -43,7 +43,7 @@ TEST_CASE(param_add)
auto p1 = create_program_float(); auto p1 = create_program_float();
auto p2 = create_program_half(); auto p2 = create_program_half();
migraphx::quantize(p1); migraphx::quantize_fp16(p1);
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -51,7 +51,7 @@ TEST_CASE(param_add) ...@@ -51,7 +51,7 @@ TEST_CASE(param_add)
auto p1 = create_program_float(); auto p1 = create_program_float();
auto p2 = create_program_half(); auto p2 = create_program_half();
migraphx::quantize(p1, {"add"}); migraphx::quantize_fp16(p1, {"add"});
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
} }
...@@ -127,7 +127,7 @@ TEST_CASE(param_add_sub) ...@@ -127,7 +127,7 @@ TEST_CASE(param_add_sub)
auto p1 = create_program_float(); auto p1 = create_program_float();
auto p2 = create_program_half_add(); auto p2 = create_program_half_add();
migraphx::quantize(p1, {"add"}); migraphx::quantize_fp16(p1, {"add"});
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -135,7 +135,7 @@ TEST_CASE(param_add_sub) ...@@ -135,7 +135,7 @@ TEST_CASE(param_add_sub)
auto p1 = create_program_float(); auto p1 = create_program_float();
auto p2 = create_program_half_sub(); auto p2 = create_program_half_sub();
migraphx::quantize(p1, {"sub"}); migraphx::quantize_fp16(p1, {"sub"});
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
...@@ -143,7 +143,7 @@ TEST_CASE(param_add_sub) ...@@ -143,7 +143,7 @@ TEST_CASE(param_add_sub)
auto p1 = create_program_float(); auto p1 = create_program_float();
auto p2 = create_program_half_all(); auto p2 = create_program_half_all();
migraphx::quantize(p1); migraphx::quantize_fp16(p1);
migraphx::run_passes(p1, {migraphx::dead_code_elimination{}}); migraphx::run_passes(p1, {migraphx::dead_code_elimination{}});
EXPECT(p1 == p2); EXPECT(p1 == p2);
...@@ -181,7 +181,7 @@ TEST_CASE(literal_add) ...@@ -181,7 +181,7 @@ TEST_CASE(literal_add)
auto p1 = create_program_float(); auto p1 = create_program_float();
auto p2 = create_program_half(); auto p2 = create_program_half();
migraphx::quantize(p1, {"all"}); migraphx::quantize_fp16(p1, {"all"});
migraphx::run_passes(p1, migraphx::run_passes(p1,
{migraphx::propagate_constant{}, migraphx::dead_code_elimination{}}); {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(p2, migraphx::run_passes(p2,
...@@ -194,7 +194,7 @@ TEST_CASE(literal_add) ...@@ -194,7 +194,7 @@ TEST_CASE(literal_add)
auto p1 = create_program_float(); auto p1 = create_program_float();
auto p2 = create_program_half(); auto p2 = create_program_half();
migraphx::quantize(p1, {"add"}); migraphx::quantize_fp16(p1, {"add"});
migraphx::run_passes(p1, migraphx::run_passes(p1,
{migraphx::propagate_constant{}, migraphx::dead_code_elimination{}}); {migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(p2, migraphx::run_passes(p2,
...@@ -313,7 +313,7 @@ TEST_CASE(dot_float) ...@@ -313,7 +313,7 @@ TEST_CASE(dot_float)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{ const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; {0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params); migraphx::quantize_int8(p, quant_params, {"dot"});
migraphx::run_passes(p, {migraphx::dead_code_elimination{}}); migraphx::run_passes(p, {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -375,7 +375,7 @@ TEST_CASE(dot_double_2args) ...@@ -375,7 +375,7 @@ TEST_CASE(dot_double_2args)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params); migraphx::quantize_int8(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -440,7 +440,7 @@ TEST_CASE(dot_large_alpha_beta_float) ...@@ -440,7 +440,7 @@ TEST_CASE(dot_large_alpha_beta_float)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{ const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params); migraphx::quantize_int8(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -504,7 +504,7 @@ TEST_CASE(dot_large_alpha_beta_int32) ...@@ -504,7 +504,7 @@ TEST_CASE(dot_large_alpha_beta_int32)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{ const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params); migraphx::quantize_int8(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -548,7 +548,7 @@ TEST_CASE(dot_int32_one_arg) ...@@ -548,7 +548,7 @@ TEST_CASE(dot_int32_one_arg)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{1.0f, 1.0f}}; const std::vector<std::pair<float, float>>& quant_params{{1.0f, 1.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params); migraphx::quantize_int8(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -622,7 +622,7 @@ TEST_CASE(dot_int32) ...@@ -622,7 +622,7 @@ TEST_CASE(dot_int32)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{ const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}}; {0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params); migraphx::quantize_int8(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -671,7 +671,7 @@ TEST_CASE(dot_float_convert) ...@@ -671,7 +671,7 @@ TEST_CASE(dot_float_convert)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params); migraphx::quantize_int8(p, quant_params, {"dot"});
migraphx::run_passes(p, {migraphx::dead_code_elimination{}}); migraphx::run_passes(p, {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -726,7 +726,7 @@ TEST_CASE(conv_float) ...@@ -726,7 +726,7 @@ TEST_CASE(conv_float)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"convolution"}, quant_params); migraphx::quantize_int8(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -782,7 +782,7 @@ TEST_CASE(conv_int32) ...@@ -782,7 +782,7 @@ TEST_CASE(conv_int32)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"convolution"}, quant_params); migraphx::quantize_int8(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -840,7 +840,7 @@ TEST_CASE(conv_half) ...@@ -840,7 +840,7 @@ TEST_CASE(conv_half)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"convolution"}, quant_params); migraphx::quantize_int8(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment