Commit 69da37bb authored by Shucai Xiao

fix review comments

parent aefffcda
......@@ -14,8 +14,8 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program;
void quantize(program& prog, const std::vector<std::string>& ins_names);
void quantize(program& prog);
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names);
void quantize_fp16(program& prog);
// insert the capture operator for the inputs of each operator to be quantized
// to int8
......@@ -24,9 +24,7 @@ std::size_t capture_arguments(program& prog,
const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments_impl(program& prog,
const target& t,
const std::vector<std::string>& ins_names = {"dot", "convolution"});
capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names);
template <class T>
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(
......@@ -40,14 +38,11 @@ std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(
void quantize_int8(program& prog,
const target& t,
std::vector<program::parameter_map>& calibration_args);
std::vector<program::parameter_map>& calibration_args,
const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8(program& prog,
const target& t,
const std::vector<std::string>& ins_names,
std::vector<program::parameter_map>& calibration_args);
void quantize_int8(program& prog,
const std::vector<std::string>& ins_names,
const std::vector<std::pair<float, float>>& quant_params);
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
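For reference, a minimal caller-side sketch of the renamed and reordered entry points declared above, assuming the quantization header is included; prog, t, and calib_args are placeholders for an existing program, a compiled target, and user-supplied calibration data, and each call is shown independently (a real caller would normally pick one path):

    // Caller-side sketch only; not part of this commit.
    void quantize_example(migraphx::program& prog,
                          const migraphx::target& t,
                          std::vector<migraphx::program::parameter_map>& calib_args)
    {
        // fp16: quantize all supported instructions (equivalent to {"all"}),
        // or restrict to specific operators by name.
        migraphx::quantize_fp16(prog);

        // int8 with precomputed (scale, shift) pairs; note the new argument
        // order: quant_params now comes before ins_names.
        std::vector<std::pair<float, float>> quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
        migraphx::quantize_int8(prog, quant_params, {"dot"});

        // int8 driven by calibration data on a target; calibration_args now
        // precedes ins_names, which defaults to {"dot", "convolution"}.
        migraphx::quantize_int8(prog, t, calib_args);
    }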
......@@ -183,22 +183,22 @@ PYBIND11_MODULE(migraphx, m)
});
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("quantize", [](migraphx::program& p, std::vector<std::string>& ins_names) {
migraphx::quantize(p, ins_names);
m.def("quantize_fp16", [](migraphx::program& p, std::vector<std::string>& ins_names) {
migraphx::quantize_fp16(p, ins_names);
});
m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); });
m.def("quantize_fp16", [](migraphx::program& p) { migraphx::quantize_fp16(p, {"all"}); });
m.def("quantize_int8",
[](migraphx::program& p,
std::vector<std::string>& ins_names,
std::vector<std::pair<float, float>>& quant_params) {
migraphx::quantize_int8(p, ins_names, quant_params);
migraphx::quantize_int8(p, quant_params, ins_names);
});
m.def("quantize_int8",
[](migraphx::program& p,
const migraphx::target& t,
std::vector<std::string>& ins_names,
std::vector<migraphx::program::parameter_map>& cali_args) {
migraphx::quantize_int8(p, t, ins_names, cali_args);
migraphx::quantize_int8(p, t, cali_args, ins_names);
});
m.def("quantize_int8",
[](migraphx::program& p,
......
......@@ -96,7 +96,7 @@ instruction_ref insert_quant_ins(program& prog,
// For the conversion, there could be cases of overflowing, but it
// is very rare in the area of deep learning, so we just do a
// truncation of the input to get the fp16.
void quantize(program& prog, const std::vector<std::string>& ins_names)
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
{
std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog))
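As a rough illustration of the truncating fp16 conversion described in the comment above, here is the kind of program quantize_fp16(p, {"add"}) would leave behind for a float add. This is a hand-written sketch, assuming the op::convert{half_type} construction used by the quantization tests later in this commit; the real pass inserts the converts itself and may also reuse existing ones or handle outputs differently:

    // Sketch only: both inputs are converted (truncated) to half before the add.
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    auto x  = p.add_parameter("x", s);
    auto y  = p.add_parameter("y", s);
    auto hx = p.add_instruction(migraphx::op::convert{migraphx::shape::half_type}, x);
    auto hy = p.add_instruction(migraphx::op::convert{migraphx::shape::half_type}, y);
    p.add_instruction(migraphx::op::add{}, hx, hy);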
......@@ -161,7 +161,7 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
}
}
void quantize(program& prog) { quantize(prog, {"all"}); }
void quantize_fp16(program& prog) { quantize_fp16(prog, {"all"}); }
static void ins_quantize_int8(program& prog,
instruction_ref ins,
......@@ -306,8 +306,8 @@ static void ins_quantize_int8(program& prog,
// a shift, then the convert can be done as v_int8 = fp * scale + shift.
// To simplify the changes, we consider shift as 0.0f for now.
void quantize_int8(program& prog,
const std::vector<std::string>& ins_names,
const std::vector<std::pair<float, float>>& quant_params)
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names)
{
// For now, we only support the int8 quantization of gemm and convolution
std::vector<std::string> op_names = {"dot", "convolution"};
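The (scale, shift) pairs passed as quant_params feed the per-value conversion described in the comment above. A small standalone illustration of that formula follows; the rounding and saturation to the int8 range are illustrative assumptions, not details taken from this change:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // v_int8 = fp * scale + shift, with shift fixed to 0.0f in this commit.
    std::int8_t to_int8(float fp, float scale, float shift = 0.0f)
    {
        float v = std::round(fp * scale + shift);
        return static_cast<std::int8_t>(std::max(-128.0f, std::min(127.0f, v)));
    }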
......@@ -402,8 +402,8 @@ void quantize_int8(program& prog,
void quantize_int8(program& prog,
const target& t,
const std::vector<std::string>& ins_names,
std::vector<program::parameter_map>& calibration_args)
std::vector<program::parameter_map>& calibration_args,
const std::vector<std::string>& ins_names)
{
// insert capture operator
auto cap_prog = prog;
......@@ -432,15 +432,7 @@ void quantize_int8(program& prog,
cap_prog.eval(m);
}
quantize_int8(prog, ins_names, *int8_quant_params);
}
void quantize_int8(program& prog,
const target& t,
std::vector<program::parameter_map>& calibration_args)
{
std::vector<std::string> ins_names = {"dot", "convolution"};
quantize_int8(prog, t, ins_names, calibration_args);
quantize_int8(prog, *int8_quant_params, ins_names);
}
// For the input of each input argument, we need to insert a
......
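For the calibration-driven overload above, the calibration arguments are ordinary parameter maps evaluated through the captured program. A hypothetical sketch of assembling them is shown below; the parameter name "x" and the shape are placeholders, prog and t stand for an existing program and target, and generate_argument is used only because it already appears in the Python bindings in this commit:

    // Hypothetical calibration setup; "x" and the shape are placeholders.
    migraphx::program::parameter_map m;
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    m["x"] = migraphx::generate_argument(s);
    std::vector<migraphx::program::parameter_map> calibration_args{m};

    // calibration_args now precedes ins_names, which keeps its
    // {"dot", "convolution"} default on this overload.
    migraphx::quantize_int8(prog, t, calibration_args);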
......@@ -1821,7 +1821,7 @@ TEST_CASE(fp32_fp16_test)
auto test_case = [&](std::vector<std::string>&& op_names) {
std::vector<float> gold_res = {2.0, 4.0, 6.0, 8.0, 10.0, 12.0};
auto p = create_program();
migraphx::quantize(p, op_names);
migraphx::quantize_fp16(p, op_names);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> res;
......
......@@ -3694,7 +3694,7 @@ struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_parameter("p2", s);
p.add_instruction(migraphx::op::add{}, l1, l2);
migraphx::quantize(p, {"all"});
migraphx::quantize_fp16(p, {"all"});
return p;
};
};
......@@ -3710,7 +3710,7 @@ struct test_fp32_fp16_ladd : verify_program<test_fp32_fp16_ladd>
auto l1 = p.add_literal(migraphx::literal(s, data));
auto l2 = p.add_parameter("p2", s);
p.add_instruction(migraphx::op::add{}, l1, l2);
migraphx::quantize(p, {"add"});
migraphx::quantize_fp16(p, {"add"});
return p;
};
};
......@@ -3726,7 +3726,7 @@ struct test_fp32_fp16_add : verify_program<test_fp32_fp16_add>
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1);
migraphx::quantize(p, {"add"});
migraphx::quantize_fp16(p, {"add"});
return p;
};
......@@ -3743,7 +3743,7 @@ struct test_fp32_fp16_sub : verify_program<test_fp32_fp16_sub>
auto sum = p.add_instruction(migraphx::op::add{}, p1, p2);
auto diff = p.add_instruction(migraphx::op::sub{}, sum, p2);
p.add_instruction(migraphx::op::add{}, diff, p1);
migraphx::quantize(p, {"sub"});
migraphx::quantize_fp16(p, {"sub"});
return p;
};
......
......@@ -43,7 +43,7 @@ TEST_CASE(param_add)
auto p1 = create_program_float();
auto p2 = create_program_half();
migraphx::quantize(p1);
migraphx::quantize_fp16(p1);
EXPECT(p1 == p2);
}
......@@ -51,7 +51,7 @@ TEST_CASE(param_add)
auto p1 = create_program_float();
auto p2 = create_program_half();
migraphx::quantize(p1, {"add"});
migraphx::quantize_fp16(p1, {"add"});
EXPECT(p1 == p2);
}
}
......@@ -127,7 +127,7 @@ TEST_CASE(param_add_sub)
auto p1 = create_program_float();
auto p2 = create_program_half_add();
migraphx::quantize(p1, {"add"});
migraphx::quantize_fp16(p1, {"add"});
EXPECT(p1 == p2);
}
......@@ -135,7 +135,7 @@ TEST_CASE(param_add_sub)
auto p1 = create_program_float();
auto p2 = create_program_half_sub();
migraphx::quantize(p1, {"sub"});
migraphx::quantize_fp16(p1, {"sub"});
EXPECT(p1 == p2);
}
......@@ -143,7 +143,7 @@ TEST_CASE(param_add_sub)
auto p1 = create_program_float();
auto p2 = create_program_half_all();
migraphx::quantize(p1);
migraphx::quantize_fp16(p1);
migraphx::run_passes(p1, {migraphx::dead_code_elimination{}});
EXPECT(p1 == p2);
......@@ -181,7 +181,7 @@ TEST_CASE(literal_add)
auto p1 = create_program_float();
auto p2 = create_program_half();
migraphx::quantize(p1, {"all"});
migraphx::quantize_fp16(p1, {"all"});
migraphx::run_passes(p1,
{migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(p2,
......@@ -194,7 +194,7 @@ TEST_CASE(literal_add)
auto p1 = create_program_float();
auto p2 = create_program_half();
migraphx::quantize(p1, {"add"});
migraphx::quantize_fp16(p1, {"add"});
migraphx::run_passes(p1,
{migraphx::propagate_constant{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(p2,
......@@ -313,7 +313,7 @@ TEST_CASE(dot_float)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params);
migraphx::quantize_int8(p, quant_params, {"dot"});
migraphx::run_passes(p, {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
......@@ -375,7 +375,7 @@ TEST_CASE(dot_double_2args)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params);
migraphx::quantize_int8(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -440,7 +440,7 @@ TEST_CASE(dot_large_alpha_beta_float)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params);
migraphx::quantize_int8(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -504,7 +504,7 @@ TEST_CASE(dot_large_alpha_beta_int32)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params);
migraphx::quantize_int8(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -548,7 +548,7 @@ TEST_CASE(dot_int32_one_arg)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{1.0f, 1.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params);
migraphx::quantize_int8(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -622,7 +622,7 @@ TEST_CASE(dot_int32)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params);
migraphx::quantize_int8(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -671,7 +671,7 @@ TEST_CASE(dot_float_convert)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"dot"}, quant_params);
migraphx::quantize_int8(p, quant_params, {"dot"});
migraphx::run_passes(p, {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
......@@ -726,7 +726,7 @@ TEST_CASE(conv_float)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"convolution"}, quant_params);
migraphx::quantize_int8(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -782,7 +782,7 @@ TEST_CASE(conv_int32)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"convolution"}, quant_params);
migraphx::quantize_int8(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -840,7 +840,7 @@ TEST_CASE(conv_half)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, {"convolution"}, quant_params);
migraphx::quantize_int8(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......