Commit 47f6954f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refactor the int8 quantization

parent 531feac3
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/op/convert.hpp> #include <migraphx/op/convert.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/mul.hpp> #include <migraphx/op/mul.hpp>
#include <migraphx/op/add.hpp> #include <migraphx/op/add.hpp>
...@@ -23,7 +25,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -23,7 +25,9 @@ inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_quant_ins(program& prog, instruction_ref insert_quant_ins(program& prog,
instruction_ref& ins, instruction_ref& ins,
shape::type_t type, shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_ins) std::unordered_map<instruction_ref, instruction_ref>& map_ins,
float scale = 1.0f,
float shift = 0.0f)
{ {
if(map_ins.count(ins) > 0) if(map_ins.count(ins) > 0)
{ {
...@@ -35,11 +39,48 @@ instruction_ref insert_quant_ins(program& prog, ...@@ -35,11 +39,48 @@ instruction_ref insert_quant_ins(program& prog,
return ins; return ins;
} }
assert(ins->get_shape().type() == shape::float_type || assert(ins->get_shape().type() == shape::float_type or
ins->get_shape().type() == shape::double_type || ins->get_shape().type() == shape::double_type or
ins->get_shape().type() == shape::int32_type); ins->get_shape().type() == shape::int32_type);
instruction_ref quant_ins{}; instruction_ref quant_ins{};
quant_ins = prog.insert_instruction(std::next(ins), op::convert{type}, ins); auto insert_loc = std::next(ins);
if (type == shape::int8_type)
{
auto scaled_ins = ins;
if (scale != 1.0f)
{
auto float_ins = scaled_ins;
if (scaled_ins->get_shape().type() != shape::float_type)
{
float_ins = prog.insert_instruction(insert_loc, op::convert{shape::float_type}, scaled_ins);
}
std::vector<float> vec_scale(scaled_ins->get_shape().elements(), scale);
auto l_scale = prog.add_literal(literal(scaled_ins->get_shape(), vec_scale));
scaled_ins = prog.insert_instruction(insert_loc, op::mul{}, l_scale, float_ins);
}
auto shifted_ins = scaled_ins;
if (shift != 0.0f)
{
auto float_ins = shifted_ins;
if (shifted_ins->get_shape().type() != shape::float_type)
{
float_ins = prog.insert_instruction(insert_loc, op::convert{shape::float_type}, shifted_ins);
}
std::vector<float> vec_shift(shifted_ins->get_shape().elements(), shift);
auto l_shift = prog.add_literal(literal(shifted_ins->get_shape(), vec_shift));
shifted_ins = prog.insert_instruction(insert_loc, op::add{}, l_shift, float_ins);
}
auto clipped_ins = prog.insert_instruction(insert_loc, op::clip{127.0f, -128.0f}, shifted_ins);
auto rounded_ins = prog.insert_instruction(insert_loc, op::round{}, clipped_ins);
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, rounded_ins);
}
else
{
quant_ins = prog.insert_instruction(insert_loc, op::convert{type}, ins);
}
map_ins[ins] = quant_ins; map_ins[ins] = quant_ins;
return quant_ins; return quant_ins;
...@@ -182,8 +223,8 @@ void quantize_int8(program& prog, ...@@ -182,8 +223,8 @@ void quantize_int8(program& prog,
} }
auto s = input->get_shape(); auto s = input->get_shape();
if((s.type() == shape::float_type || s.type() == shape::double_type || if((s.type() == shape::float_type or s.type() == shape::double_type or
s.type() == shape::int32_type) && s.type() == shape::int32_type) and
s.type() != quant_type) s.type() != quant_type)
{ {
// if the input is a convert operator, uses its input // if the input is a convert operator, uses its input
...@@ -198,12 +239,12 @@ void quantize_int8(program& prog, ...@@ -198,12 +239,12 @@ void quantize_int8(program& prog,
} }
else else
{ {
quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins); quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins, param.first, param.second);
} }
} }
else else
{ {
quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins); quant_input = insert_quant_ins(prog, input, quant_type, map_quant_ins, param.first, param.second);
} }
converted_inputs.push_back(quant_input); converted_inputs.push_back(quant_input);
} }
...@@ -298,7 +339,29 @@ void quantize_int8(program& prog, ...@@ -298,7 +339,29 @@ void quantize_int8(program& prog,
ins, ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group}, op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs); converted_inputs);
prog.replace_instruction(ins, op::convert{orig_type}, quant_conv); float threshold = 50.0f;
std::vector<float> vec_factor(quant_conv->get_shape().elements(), adjust_factor);
if (quant_conv->get_shape().type() == orig_type and adjust_factor >= threshold)
{
auto l_factor = prog.add_literal(literal(quant_conv->get_shape(), vec_factor.begin(), vec_factor.end()));
prog.replace_instruction(ins, op::mul{}, quant_conv, l_factor);
}
// convert quant_conv output to float type, multiply the factor and
// conver back to original type
else
{
auto float_conv = prog.insert_instruction(ins, op::convert{shape::float_type}, quant_conv);
auto l_factor = prog.add_literal(literal(float_conv->get_shape(), vec_factor));
if (orig_type == shape::float_type)
{
prog.replace_instruction(ins, op::mul{}, l_factor, float_conv);
}
else
{
auto adjusted_conv = prog.insert_instruction(ins, op::mul{}, l_factor, float_conv);
prog.replace_instruction(ins, op::convert{orig_type}, adjusted_conv);
}
}
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment