"tools/vscode:/vscode.git/clone" did not exist on "1b692d0fa298270fef438f3bc2394028d9272e75"
Commit 4c75e7fa authored by Shucai Xiao
Browse files

Small fix to achieve the same result between CPU and GPU

parent ae1cb853
......@@ -196,6 +196,9 @@ PYBIND11_MODULE(migraphx, m)
std::vector<std::pair<float, float>>& quant_params) {
migraphx::quantize_int8(p, ins_names, quant_params);
});
m.def("quantize_int8", [](migraphx::program& p, std::vector<std::string>& ins_names) {
migraphx::quantize_int8(p, ins_names);
});
m.def("quantize_int8", [](migraphx::program& p) { migraphx::quantize_int8(p); });
m.def("capture_arguments", [](migraphx::program& p, const std::vector<std::string>& ins_names) {
......
......@@ -74,10 +74,10 @@ instruction_ref insert_quant_ins(program& prog,
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
}
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, shifted_ins);
auto clipped_ins =
prog.insert_instruction(insert_loc, op::clip{127.0f, -128.0f}, shifted_ins);
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, clipped_ins);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, rounded_ins);
prog.insert_instruction(insert_loc, op::clip{127.0f, -128.0f}, rounded_ins);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, clipped_ins);
}
else
{
......@@ -283,8 +283,7 @@ void quantize_int8(program& prog,
{
int32_t quant_alpha = static_cast<int32_t>(new_alpha);
int32_t quant_beta = static_cast<int32_t>(new_beta);
shape quant_shape = compute_shape(op::quant_dot{1, 0}, converted_inputs);
if(quant_shape.type() == orig_type)
if(shape::int32_type == orig_type)
{
prog.replace_instruction(
ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
......@@ -300,6 +299,10 @@ void quantize_int8(program& prog,
// relative rounding error
else
{
if (converted_inputs.size() == 3)
{
converted_inputs.pop_back();
}
auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = prog.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto c_shape = q_dot->get_shape();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment