Commit 33ce5786 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add unit test for int8 quantization

parent dfb016fe
...@@ -43,8 +43,8 @@ void quantize_int8(program& prog, ...@@ -43,8 +43,8 @@ void quantize_int8(program& prog,
std::vector<program::parameter_map>& calibration_args); std::vector<program::parameter_map>& calibration_args);
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
std::vector<program::parameter_map>& calibration_args, const std::vector<std::string>& ins_names,
const std::vector<std::string>& ins_names); std::vector<program::parameter_map>& calibration_args);
void quantize_int8(program& prog, void quantize_int8(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
const std::vector<std::pair<float, float>>& quant_params); const std::vector<std::pair<float, float>>& quant_params);
......
...@@ -196,9 +196,9 @@ PYBIND11_MODULE(migraphx, m) ...@@ -196,9 +196,9 @@ PYBIND11_MODULE(migraphx, m)
m.def("quantize_int8", m.def("quantize_int8",
[](migraphx::program& p, [](migraphx::program& p,
const migraphx::target& t, const migraphx::target& t,
std::vector<migraphx::program::parameter_map>& cali_args, std::vector<std::string>& ins_names,
std::vector<std::string>& ins_names) { std::vector<migraphx::program::parameter_map>& cali_args) {
migraphx::quantize_int8(p, t, cali_args, ins_names); migraphx::quantize_int8(p, t, ins_names, cali_args);
}); });
m.def("quantize_int8", m.def("quantize_int8",
[](migraphx::program& p, [](migraphx::program& p,
......
...@@ -162,7 +162,7 @@ void quantize(program& prog, const std::vector<std::string>& ins_names) ...@@ -162,7 +162,7 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
void quantize(program& prog) { quantize(prog, {"all"}); } void quantize(program& prog) { quantize(prog, {"all"}); }
static void quantize_ins(program& prog, static void ins_quantize_int8(program& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref>& converted_inputs, std::vector<instruction_ref>& converted_inputs,
const std::vector<std::pair<float, float>>& ins_quant_params) const std::vector<std::pair<float, float>>& ins_quant_params)
...@@ -180,8 +180,8 @@ static void quantize_ins(program& prog, ...@@ -180,8 +180,8 @@ static void quantize_ins(program& prog,
float threshold = 50.0f; float threshold = 50.0f;
if(fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold) if(fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold)
{ {
int32_t quant_alpha = static_cast<int32_t>(new_alpha); int32_t quant_alpha = static_cast<int32_t>(std::round(new_alpha));
int32_t quant_beta = static_cast<int32_t>(new_beta); int32_t quant_beta = static_cast<int32_t>(std::round(new_beta));
if(shape::int32_type == orig_type) if(shape::int32_type == orig_type)
{ {
prog.replace_instruction( prog.replace_instruction(
...@@ -308,14 +308,6 @@ void quantize_int8(program& prog, ...@@ -308,14 +308,6 @@ void quantize_int8(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
const std::vector<std::pair<float, float>>& quant_params) const std::vector<std::pair<float, float>>& quant_params)
{ {
// for(size_t i = 0; i < quant_params.size(); i++)
// {
// auto param = quant_params.at(i);
// std::cout << "index = " << i << ", scale = " << param.first << "\t" << param.second
// << std::endl;
// }
// std::cout << std::endl;
// For now, we only support the int8 quantization of gemm and convolution // For now, we only support the int8 quantization of gemm and convolution
std::vector<std::string> op_names = {"dot", "convolution"}; std::vector<std::string> op_names = {"dot", "convolution"};
if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) { if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
...@@ -395,7 +387,7 @@ void quantize_int8(program& prog, ...@@ -395,7 +387,7 @@ void quantize_int8(program& prog,
continue; continue;
} }
quantize_ins(prog, ins, converted_inputs, ins_quant_params); ins_quantize_int8(prog, ins, converted_inputs, ins_quant_params);
} }
if(quant_param_index != quant_params.size()) if(quant_param_index != quant_params.size())
...@@ -406,8 +398,8 @@ void quantize_int8(program& prog, ...@@ -406,8 +398,8 @@ void quantize_int8(program& prog,
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
std::vector<program::parameter_map>& calibration_args, const std::vector<std::string>& ins_names,
const std::vector<std::string>& ins_names) std::vector<program::parameter_map>& calibration_args)
{ {
// insert capture operator // insert capture operator
auto cap_prog = prog; auto cap_prog = prog;
...@@ -444,7 +436,7 @@ void quantize_int8(program& prog, ...@@ -444,7 +436,7 @@ void quantize_int8(program& prog,
std::vector<program::parameter_map>& calibration_args) std::vector<program::parameter_map>& calibration_args)
{ {
std::vector<std::string> ins_names = {"dot", "convolution"}; std::vector<std::string> ins_names = {"dot", "convolution"};
quantize_int8(prog, t, calibration_args, ins_names); quantize_int8(prog, t, ins_names, calibration_args);
} }
// For the input of each input argument, we need to insert a // For the input of each input argument, we need to insert a
......
...@@ -65,4 +65,69 @@ TEST_CASE(target_copy) ...@@ -65,4 +65,69 @@ TEST_CASE(target_copy)
} }
} }
TEST_CASE(dot_large_alpha_beta_float)
{
auto create_program = [] {
migraphx::program p;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
auto pc = p.add_parameter("c", sc);
p.add_instruction(migraphx::op::dot{20.0f, 50.5f}, pa, pb, pc);
return p;
};
auto create_int8_quantized_prog = [] {
migraphx::program p;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
auto pc = p.add_parameter("c", sc);
// quantize parameter a to int8 type, multiply the scale
std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = p.add_literal(migraphx::literal(sa, vfa));
auto ma = p.add_instruction(migraphx::op::mul{}, fa, pa);
// add the shift
std::vector<float> vsa(sa.elements(), 1.0f);
auto sfta = p.add_literal(migraphx::literal(sa, vsa));
auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma);
auto ra = p.add_instruction(migraphx::op::round{}, msa);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = p.add_literal(migraphx::literal(sb, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb);
auto qb = p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
// quantize parameter b to int32 type
auto qc = p.insert_instruction(std::next(pc), migraphx::op::convert{migraphx::shape::int32_type}, pc);
auto qdot = p.add_instruction(migraphx::op::quant_dot{2000, 51}, qa, qb, qc);
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot);
return p;
};
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
// default scale 64.0f is used for all args
migraphx::quantize_int8(p, {"dot"}, quant_params);
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment