Commit 5447f39c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed a bug in the int8 quantization function

parent 965ac6fc
...@@ -301,10 +301,14 @@ void quantize_int8(program& prog, ...@@ -301,10 +301,14 @@ void quantize_int8(program& prog,
else else
{ {
auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs); auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
auto f_dot = prog.insert_instruction(ins, op::convert{shape::float_type}, q_dot);
auto c_shape = q_dot->get_shape();
std::vector<float> vec_alpha(c_shape.elements(), new_alpha);
auto l_alpha = prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_alpha));
if(inputs.size() == 3 and dot_op.beta != 0.0f) if(inputs.size() == 3 and dot_op.beta != 0.0f)
{ {
auto alpha_ab = prog.insert_instruction(ins, op::convert{orig_type}, q_dot); auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
auto c_shape = q_dot->get_shape();
std::vector<float> vec_beta(c_shape.elements(), dot_op.beta); std::vector<float> vec_beta(c_shape.elements(), dot_op.beta);
auto l_beta = auto l_beta =
prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta)); prog.add_literal(literal({shape::float_type, c_shape.lens()}, vec_beta));
...@@ -320,11 +324,28 @@ void quantize_int8(program& prog, ...@@ -320,11 +324,28 @@ void quantize_int8(program& prog,
{ {
beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back()); beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
} }
if (orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c); prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
} }
else else
{ {
prog.replace_instruction(ins, op::convert{orig_type}, q_dot); auto f_res = prog.insert_instruction(ins, op::add{}, alpha_ab, beta_c);
prog.replace_instruction(ins, op::convert{orig_type}, f_res);
}
}
else
{
if (orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::mul{}, l_alpha, f_dot);
}
else
{
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, f_dot);
prog.replace_instruction(ins, op::convert{orig_type}, alpha_ab);
}
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment