Commit ea9776f5 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code backup for the int8 quantization

parent f00002d3
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/op/convert.hpp> #include <migraphx/op/convert.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp> #include <migraphx/op/mul.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/quant_dot.hpp> #include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/convolution.hpp> #include <migraphx/op/convolution.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
...@@ -152,15 +153,15 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -152,15 +153,15 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
// operation, if it has 3 inputs, then the last one should // operation, if it has 3 inputs, then the last one should
// be converted to int32_type // be converted to int32_type
shape::type_t quant_type = shape::int8_type; shape::type_t quant_type = shape::int8_type;
auto param = int8_param[param_index++];
if(ins->name() == "dot" and inputs.size() == 3 and input == inputs.back()) if(ins->name() == "dot" and inputs.size() == 3 and input == inputs.back())
{ {
quant_type = shape::int32_type; quant_type = shape::int32_type;
} }
auto param = int8_param[param_index++];
auto s = input->get_shape(); auto s = input->get_shape();
if(s.type() == shape::float_type || s.type() == shape::double_type || if((s.type() == shape::float_type || s.type() == shape::double_type ||
s.type() == shape::int32_type) s.type() == shape::int32_type) && s.type() != quant_type)
{ {
// if the input is a convert operator, uses its input // if the input is a convert operator, uses its input
// as its current input // as its current input
...@@ -197,42 +198,116 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -197,42 +198,116 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
continue; continue;
} }
auto op = ins->get_operator(); // When converting from other types to int8_type, there are parameters
shape ins_shape{}; // used as scale and shift(.0f), which will generate results diffrent from
// just to compute the output shape // the original results. To adjust the output to be "correct(approximatly
// equal)", we need additional calculation for the adjustment
if(ins->name() == "dot") if(ins->name() == "dot")
{ {
ins_shape = compute_shape(op::quant_dot{}, converted_inputs); auto dot_op = any_cast<op::dot>(ins->get_operator());
float new_alpha = dot_op.alpha / (int8_param[0].first * int8_param[1].first);
float new_beta = dot_op.beta;
// We need additional checking about the quant_alpha value. If
// abs(quant_alpha) > 50 (some tmp value set here), we can convert
// it to an integer as the new_alpha in the quant_dot
float threshold = 50.0f;
if (fabs(new_alpha) >= threshold && fabs(new_beta) >= threshold)
{
int32_t quant_alpha = static_cast<int32_t>(new_alpha);
int32_t quant_beta = static_cast<int32_t>(new_beta);
shape quant_shape = compute_shape(op::quant_dot{1, 0}, converted_inputs);
if (quant_shape.type() == orig_type)
{
prog.replace_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
} }
else else
{ {
ins_shape = compute_shape(op::quant_convolution{}, converted_inputs); auto quant_dot = prog.insert_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
prog.replace_instruction(ins, op::convert{orig_type}, quant_dot);
} }
}
if(ins_shape.type() != orig_type) // only alpha can be quantized, quantization of beta will cause
// big error, so we have to manually do the multiplication and
// addition
else if (fabs(new_alpha) >= threshold)
{ {
// check the dead code case to avoid assert int32_t quant_alpha = static_cast<int32_t>(new_alpha);
bool output_empty = ins->outputs().empty(); int32_t quant_beta = 0;
// this conversion can be only from int32 to float or double if (orig_type == shape::int32_type)
auto ins_orig_type =
prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
if(!output_empty)
{ {
prog.replace_instruction(ins, ins_orig_type); if (inputs.size() == 2 or dot_op.beta == 0.0f)
{
prog.replace_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
} }
// if there are 3 inputs, we need to consider the third argument
else
{
auto q_dot = prog.insert_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
std::vector<float> vec_beta(q_dot->get_shape().elements(), dot_op.beta);
auto l_beta = prog.add_literal(literal{orig_type, vec_beta});
auto beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
prog.replace_instruction(ins, op::add{}, q_dot, beta_c);
} }
}
// When converting from other types to int8_type, there are parameters else
// used as scale and shift(.0f), which will generate results diffrent from
// the original results. To adjust the output to be "correct(approximatly
// equal)", we need additional calculation for the adjustment
if(ins->name() == "dot")
{ {
auto dot_op = any_cast<op::dot>(ins->get_operator()); if (inputs.size() == 2 or dot_op.beta == 0.0f)
int32_t quant_alpha = static_cast<int32_t>( {
dot_op.alpha / (int8_param[0].first * int8_param[1].first) + 0.5f); auto q_dot = prog.insert_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
int32_t quant_beta = static_cast<int32_t>(dot_op.beta + 0.5f); prog.replace_instruction(ins, op::convert{orig_type}, q_dot);
prog.replace_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs); }
// if there are 3 inputs, we need to consider the third argument
else
{
auto q_dot = prog.insert_instruction(ins, op::quant_dot{quant_alpha, quant_beta}, converted_inputs);
auto oq_dot = prog.insert_instruction(ins, op::convert{orig_type}, q_dot);
std::vector<float> vec_beta(q_dot->get_shape().elements(), dot_op.beta);
auto l_beta = prog.add_literal(literal{oq_dot->get_shape(), vec_beta});
auto beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
prog.replace_instruction(ins, op::add{}, q_dot, beta_c);
}
}
}
else
{
auto q_dot = prog.insert_instruction(ins, op::quant_dot{1, 0}, converted_inputs);
std::vector<float> vec_alpha(q_dot->get_shape().elements(), new_alpha);
if (orig_type == shape::int32_type)
{
auto l_alpha = prog.add_literal(literal(ins->get_shape(), vec_alpha));
if (converted_inputs.size() == 2 or dot_op.beta == 0.0f)
{
prog.replace_instruction(ins, op::mul{}, l_alpha, q_dot);
}
// case of 3 arguments
else
{
std::vector<float> vec_beta(ins->get_shape().elements(), new_beta);
auto l_beta = prog.add_literal(literal(ins->get_shape(), vec_beta));
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, q_dot);
auto beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
}
}
else
{
auto oq_dot = prog.insert_instruction(ins, op::convert{orig_type}, q_dot);
auto l_alpha = prog.add_literal(literal(ins->get_shape(), vec_alpha));
if (converted_inputs.size() == 2 or dot_op.beta == 0.0f)
{
prog.replace_instruction(ins, op::mul{}, l_alpha, oq_dot);
}
// case of 3 arguments
else
{
std::vector<float> vec_beta(ins->get_shape().elements(), new_beta);
auto l_beta = prog.add_literal(literal(ins->get_shape(), vec_beta));
auto alpha_ab = prog.insert_instruction(ins, op::mul{}, l_alpha, oq_dot);
auto beta_c = prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
prog.replace_instruction(ins, op::add{}, alpha_ab, beta_c);
}
}
}
} }
else if(ins->name() == "convolution") else if(ins->name() == "convolution")
{ {
...@@ -246,16 +321,34 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names) ...@@ -246,16 +321,34 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
auto group = conv_op.group; auto group = conv_op.group;
auto adjust_factor = 1.0 / (int8_param[0].first * int8_param[1].first); auto adjust_factor = 1.0 / (int8_param[0].first * int8_param[1].first);
auto conv_res = prog.insert_instruction( shape quant_shape = compute_shape(op::quant_convolution{}, converted_inputs);
ins, std::vector<float> vec_factor(quant_shape.elements(), adjust_factor);
op::quant_convolution{padding, stride, dilation, padding_mode, group}, auto fl = prog.add_literal(literal{{orig_type, quant_shape.lens()}, vec_factor});
converted_inputs); if (quant_shape.type() == orig_type)
auto conv_s = conv_res->get_shape(); {
std::vector<float> vec_fact(conv_s.elements(), adjust_factor); if (adjust_factor == 1.0f)
{
auto fl = prog.add_literal(literal{conv_s, vec_fact}); prog.replace_instruction(ins, op::quant_convolution{padding, stride, dilation, padding_mode, group}, converted_inputs);
auto ad_res = prog.insert_instruction(ins, op::mul{}, conv_res, fl); }
prog.replace_instruction(ins, ad_res); else
{
auto quant_conv = prog.replace_instruction(ins, op::quant_convolution{padding, stride, dilation, padding_mode, group}, converted_inputs);
prog.replace_instruction(ins, op::mul{}, quant_conv, fl);
}
}
else
{
auto quant_conv = prog.insert_instruction(ins, op::quant_convolution{padding, stride, dilation, padding_mode, group}, converted_inputs);
if (adjust_factor == 1.0f)
{
prog.replace_instruction(ins, op::convert{orig_type}, quant_conv);
}
else
{
auto oq_conv = prog.insert_instruction(ins, op::convert{orig_type}, quant_conv);
prog.replace_instruction(ins, op::mul{}, oq_conv, fl);
}
}
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment