"configs/vscode:/vscode.git/clone" did not exist on "e6fb90ca332594b4e6d3d8edecf8a90756863daa"
Commit 4c75e7fa authored by Shucai Xiao
Browse files

Small fix to achieve the same result on CPU and GPU

parent ae1cb853
...@@ -196,6 +196,9 @@ PYBIND11_MODULE(migraphx, m) ...@@ -196,6 +196,9 @@ PYBIND11_MODULE(migraphx, m)
std::vector<std::pair<float, float>>& quant_params) { std::vector<std::pair<float, float>>& quant_params) {
migraphx::quantize_int8(p, ins_names, quant_params); migraphx::quantize_int8(p, ins_names, quant_params);
}); });
m.def("quantize_int8", [](migraphx::program& p, std::vector<std::string>& ins_names) {
migraphx::quantize_int8(p, ins_names);
});
m.def("quantize_int8", [](migraphx::program& p) { migraphx::quantize_int8(p); }); m.def("quantize_int8", [](migraphx::program& p) { migraphx::quantize_int8(p); });
m.def("capture_arguments", [](migraphx::program& p, const std::vector<std::string>& ins_names) { m.def("capture_arguments", [](migraphx::program& p, const std::vector<std::string>& ins_names) {
......
...@@ -74,10 +74,10 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -74,10 +74,10 @@ instruction_ref insert_quant_ins(program& prog,
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins); shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
} }
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto clipped_ins = auto clipped_ins =
prog.insert_instruction(insert_loc, op::clip{127.0f, -128.0f}, shifted_ins); prog.insert_instruction(insert_loc, op::clip{127.0f, -128.0f}, rounded_ins);
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, clipped_ins); quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, rounded_ins);
} }
else else
{ {
...@@ -283,8 +283,7 @@ void quantize_int8(program& prog, ...@@ -283,8 +283,7 @@ void quantize_int8(program& prog,
{ {
int32_t quant_alpha = static_cast<int32_t>(new_alpha); int32_t quant_alpha = static_cast<int32_t>(new_alpha);
int32_t quant_beta = static_cast<int32_t>(new_beta); int32_t quant_beta = static_cast<int32_t>(new_beta);
shape quant_shape = compute_shape(op::quant_dot{1, 0}, converted_inputs); if(shape::int32_type == orig_type)
if(quant_shape.type() == orig_type)
{ {
prog.replace_instruction( prog.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
...@@ -300,6 +299,10 @@ void quantize_int8(program& prog, ...@@ -300,6 +299,10 @@ void quantize_int8(program& prog,
// relative rounding error // relative rounding error
else else
{ {
if (converted_inputs.size() == 3)
{
converted_inputs.pop_back();
}
auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs); auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = prog.insert_instruction(ins, op::convert{shape::float_type}, q_dot); auto f_dot = prog.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto c_shape = q_dot->get_shape(); auto c_shape = q_dot->get_shape();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment