"vscode:/vscode.git/clone" did not exist on "64c628434d9afebf7fa39ecfb59ab2a4acc17acd"
Commit 33ce5786 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add unit test for int8 quantization

parent dfb016fe
......@@ -43,8 +43,8 @@ void quantize_int8(program& prog,
std::vector<program::parameter_map>& calibration_args);
void quantize_int8(program& prog,
const target& t,
std::vector<program::parameter_map>& calibration_args,
const std::vector<std::string>& ins_names);
const std::vector<std::string>& ins_names,
std::vector<program::parameter_map>& calibration_args);
void quantize_int8(program& prog,
const std::vector<std::string>& ins_names,
const std::vector<std::pair<float, float>>& quant_params);
......
......@@ -196,9 +196,9 @@ PYBIND11_MODULE(migraphx, m)
m.def("quantize_int8",
[](migraphx::program& p,
const migraphx::target& t,
std::vector<migraphx::program::parameter_map>& cali_args,
std::vector<std::string>& ins_names) {
migraphx::quantize_int8(p, t, cali_args, ins_names);
std::vector<std::string>& ins_names,
std::vector<migraphx::program::parameter_map>& cali_args) {
migraphx::quantize_int8(p, t, ins_names, cali_args);
});
m.def("quantize_int8",
[](migraphx::program& p,
......
......@@ -162,7 +162,7 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
// Convenience overload: forwards to the name-list overload with the single
// instruction-name entry "all".
void quantize(program& prog)
{
    const std::vector<std::string> ins_names{"all"};
    quantize(prog, ins_names);
}
static void quantize_ins(program& prog,
static void ins_quantize_int8(program& prog,
instruction_ref ins,
std::vector<instruction_ref>& converted_inputs,
const std::vector<std::pair<float, float>>& ins_quant_params)
......@@ -180,8 +180,8 @@ static void quantize_ins(program& prog,
float threshold = 50.0f;
if(fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold)
{
int32_t quant_alpha = static_cast<int32_t>(new_alpha);
int32_t quant_beta = static_cast<int32_t>(new_beta);
int32_t quant_alpha = static_cast<int32_t>(std::round(new_alpha));
int32_t quant_beta = static_cast<int32_t>(std::round(new_beta));
if(shape::int32_type == orig_type)
{
prog.replace_instruction(
......@@ -308,14 +308,6 @@ void quantize_int8(program& prog,
const std::vector<std::string>& ins_names,
const std::vector<std::pair<float, float>>& quant_params)
{
// for(size_t i = 0; i < quant_params.size(); i++)
// {
// auto param = quant_params.at(i);
// std::cout << "index = " << i << ", scale = " << param.first << "\t" << param.second
// << std::endl;
// }
// std::cout << std::endl;
// For now, we only support the int8 quantization of gemm and convolution
std::vector<std::string> op_names = {"dot", "convolution"};
if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
......@@ -395,7 +387,7 @@ void quantize_int8(program& prog,
continue;
}
quantize_ins(prog, ins, converted_inputs, ins_quant_params);
ins_quantize_int8(prog, ins, converted_inputs, ins_quant_params);
}
if(quant_param_index != quant_params.size())
......@@ -406,8 +398,8 @@ void quantize_int8(program& prog,
void quantize_int8(program& prog,
const target& t,
std::vector<program::parameter_map>& calibration_args,
const std::vector<std::string>& ins_names)
const std::vector<std::string>& ins_names,
std::vector<program::parameter_map>& calibration_args)
{
// insert capture operator
auto cap_prog = prog;
......@@ -444,7 +436,7 @@ void quantize_int8(program& prog,
std::vector<program::parameter_map>& calibration_args)
{
std::vector<std::string> ins_names = {"dot", "convolution"};
quantize_int8(prog, t, calibration_args, ins_names);
quantize_int8(prog, t, ins_names, calibration_args);
}
// For the input of each input argument, we need to insert a
......
......@@ -65,4 +65,69 @@ TEST_CASE(target_copy)
}
}
// Quantizes a float dot instruction with explicit int8 quantization parameters
// and checks that the resulting program equals a hand-built quant_dot program.
// The dot is constructed with the constants 20.0f and 50.5f (presumably alpha
// and beta, per the test name -- confirm against op::dot).
TEST_CASE(dot_large_alpha_beta_float)
{
// Builds the original float program: one dot over parameters a (2x16),
// b (16x8) and c (2x8).
auto create_program = [] {
migraphx::program p;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
auto pc = p.add_parameter("c", sc);
p.add_instruction(migraphx::op::dot{20.0f, 50.5f}, pa, pb, pc);
return p;
};
// Builds, by hand, the program that quantize_int8 is expected to produce.
// The order of the add/insert calls below determines the instruction order
// of the program that is compared with operator== at the end of the test.
auto create_int8_quantized_prog = [] {
migraphx::program p;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
auto pc = p.add_parameter("c", sc);
// quantize parameter a to int8 type: multiply by the scale (0.1f)
std::vector<float> vfa(sa.elements(), 0.1f);
auto fa = p.add_literal(migraphx::literal(sa, vfa));
auto ma = p.add_instruction(migraphx::op::mul{}, fa, pa);
// add the shift (1.0f), then round, clip and convert to int8
std::vector<float> vsa(sa.elements(), 1.0f);
auto sfta = p.add_literal(migraphx::literal(sa, vsa));
auto msa = p.add_instruction(migraphx::op::add{}, sfta, ma);
auto ra = p.add_instruction(migraphx::op::round{}, msa);
auto ca = p.add_instruction(migraphx::op::clip{127.0f, -128.0f}, ra);
auto qa = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, ca);
// quantize parameter b to int8 type; the instructions are inserted at the
// position right after pb, and no shift is added for b
auto insert_loc = std::next(pb);
std::vector<float> vfb(sb.elements(), 0.1f);
auto fb = p.add_literal(migraphx::literal(sb, vfb));
auto mb = p.insert_instruction(insert_loc, migraphx::op::mul{}, fb, pb);
auto rb = p.insert_instruction(insert_loc, migraphx::op::round{}, mb);
auto cb = p.insert_instruction(insert_loc, migraphx::op::clip{127.0f, -128.0f}, rb);
auto qb = p.insert_instruction(insert_loc, migraphx::op::convert{migraphx::shape::int8_type}, cb);
// convert parameter c to int32 type (no scaling or shifting is applied here)
auto qc = p.insert_instruction(std::next(pc), migraphx::op::convert{migraphx::shape::int32_type}, pc);
// NOTE(review): 2000 and 51 are presumably 20.0 / (0.1 * 0.1) and 50.5 rounded
// to the nearest integer -- confirm against the alpha/beta computation in
// the int8 quantization code.
auto qdot = p.add_instruction(migraphx::op::quant_dot{2000, 51}, qa, qb, qc);
// convert the int32 quant_dot result back to float
p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, qdot);
return p;
};
auto p = create_program();
// One pair per dot argument (a, b, c); the pairs for a and b match the
// {scale, shift} values used in the expected program above (presumably that
// is the pair layout -- confirm against quantize_int8). The const reference
// binds to a temporary vector whose lifetime is extended to the end of this scope.
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
// quantize only "dot" instructions, using the explicit parameters above
migraphx::quantize_int8(p, {"dot"}, quant_params);
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
}
// Test-binary entry point: hands the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment