Commit 675d3b5b authored by Shucai Xiao

Change quantize_int8 following the changes to quant_convolution.

parent 6e01c7cb
@@ -17,7 +17,7 @@ void quantize(program& prog);

 void quantize_int8(program& prog,
                    const std::vector<std::string>& ins_names,
-                   std::vector<std::pair<float, float>>& int8_quant_params);
+                   const std::vector<std::pair<float, float>>& quant_params);

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -13,6 +13,8 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/ranges.hpp>
 #include <utility>
+#include <iomanip>
+#include <fstream>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -34,6 +36,11 @@ instruction_ref insert_quant_ins(program& prog,
         return ins;
     }

+    if(scale < 0.0f)
+    {
+        MIGRAPHX_THROW("INSERT_QUANT_INS: scale less than 0");
+    }
+
     assert(ins->get_shape().type() == shape::float_type ||
            ins->get_shape().type() == shape::double_type ||
            ins->get_shape().type() == shape::int32_type);
@@ -115,41 +122,23 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
 void quantize(program& prog) { quantize(prog, {"all"}); }

-static std::vector<std::pair<float, float>> int8_quant_params;
-
 // int8 quantization is different from fp16 since int8 can only handle values
 // in -128 ~ 127. To convert a float or double to int8, we need a scale and
 // a shift; the conversion can then be done as v_int8 = fp * scale + shift.
 // To simplify the changes, we consider shift as 0.0f for now.
 void quantize_int8(program& prog,
                    const std::vector<std::string>& ins_names,
-                   std::vector<std::pair<float, float>>& int8_quant_params)
+                   const std::vector<std::pair<float, float>>& quant_params)
 {
-    // // For debugging
-    // auto print_gemm_res = [&](std::size_t ins_index, std::vector<migraphx::argument> args) {
-    //     // scale and shift is need for only int8 type, and we do not
-    //     // consider shift, so set shift to 0
-    //     std::vector<float> vec_val;
-    //     args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
-    //     std::cout << "quant_gemm = " << std::endl;
-    //     for(size_t i = 0; i < 20; i++)
-    //     {
-    //         std::cout << vec_val[i] << "\t";
-    //     }
-    //     std::cout << std::endl;
-    // };
-
-    // // For debugging
-    // auto print_conv_res = [&](std::size_t ins_index, std::vector<migraphx::argument> args) {
-    //     // scale and shift is need for only int8 type, and we do not
-    //     // consider shift, so set shift to 0
-    //     std::vector<float> vec_val;
-    //     args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
-    //     std::cout << "quant_conv = " << std::endl;
-    //     for(size_t i = 0; i < 20; i++)
-    //     {
-    //         std::cout << vec_val[i] << "\t";
-    //     }
-    //     std::cout << std::endl;
-    // };
+    for(size_t i = 0; i < quant_params.size(); i++)
+    {
+        auto param = quant_params.at(i);
+        std::cout << "index = " << i << ", scale = " << param.first << "\t" << param.second
+                  << std::endl;
+    }
+    std::cout << std::endl;

     // For now, we only support the int8 quantization of gemm and convolution
     std::vector<std::string> op_names = {"dot", "convolution"};
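A note on the conversion formula in the comment above: with the shift fixed at 0.0f, quantizing reduces to scaling, rounding, and clamping to the int8 range. The following standalone sketch illustrates that arithmetic; the helper name and the use of std::round are illustrative assumptions, not MIGraphX code.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical helper: v_int8 = fp * scale + shift, with shift = 0.0f,
// rounded and clamped to the int8 range [-128, 127].
std::vector<int8_t> quantize_to_int8(const std::vector<float>& fp, float scale)
{
    std::vector<int8_t> result;
    result.reserve(fp.size());
    for(float v : fp)
    {
        float q = std::round(v * scale); // shift is taken as 0.0f for now
        q       = std::max(-128.0f, std::min(127.0f, q));
        result.push_back(static_cast<int8_t>(q));
    }
    return result;
}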
@@ -162,6 +151,7 @@ void quantize_int8(program& prog,
     std::size_t quant_param_index = 0;
     std::unordered_map<instruction_ref, instruction_ref> map_quant_ins;
+    std::unordered_map<instruction_ref, std::size_t> map_index;
     for(auto ins : iterator_for(prog))
     {
         if(not contains(ins_names, ins->name()))
@@ -182,12 +172,18 @@ void quantize_int8(program& prog,
             std::vector<std::pair<float, float>> ins_quant_params;
             for(auto input : inputs)
             {
+                // calculate the index of each instruction to be quantized
+                if(map_index.count(input) == 0)
+                {
+                    map_index[input] = quant_param_index++;
+                }
+                auto param = quant_params[map_index[input]];
+                ins_quant_params.push_back(param);
+
                 // In general, the target_type is int8, but for the dot
                 // operation, if it has 3 inputs, then the last one should
                 // be converted to int32_type
                 shape::type_t quant_type = shape::int8_type;
-                auto param = int8_quant_params[quant_param_index++];
-                ins_quant_params.push_back(param);
                 if(ins->name() == "dot" and inputs.size() == 3 and input == inputs.back())
                 {
                     quant_type = shape::int32_type;
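The map_index bookkeeping in this hunk assigns one entry of quant_params per distinct input instruction, in first-seen order, so an input feeding several quantized operations consumes only a single scale. A simplified sketch of that indexing scheme, using int as a stand-in for instruction_ref (all names here are illustrative):

#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

// Returns the scale/shift pair for an input, allocating a new slot only
// the first time the input is seen.
std::pair<float, float> param_for(int input,
                                  const std::vector<std::pair<float, float>>& quant_params,
                                  std::unordered_map<int, std::size_t>& map_index,
                                  std::size_t& quant_param_index)
{
    if(map_index.count(input) == 0)
    {
        map_index[input] = quant_param_index++;
    }
    return quant_params.at(map_index[input]);
}

int main()
{
    std::vector<std::pair<float, float>> quant_params = {{0.5f, 0.0f}, {0.25f, 0.0f}};
    std::unordered_map<int, std::size_t> map_index;
    std::size_t quant_param_index = 0;

    auto a = param_for(7, quant_params, map_index, quant_param_index); // new input -> slot 0
    auto b = param_for(9, quant_params, map_index, quant_param_index); // new input -> slot 1
    auto c = param_for(7, quant_params, map_index, quant_param_index); // seen before -> slot 0 again
    return (a == c and a != b) ? 0 : 1;
}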
@@ -269,6 +265,8 @@ void quantize_int8(program& prog,
                 // addition
                 else if(fabs(new_alpha) >= threshold)
                 {
+                    // round to the nearest integer (half away from zero)
+                    new_alpha = new_alpha > 0.0 ? new_alpha + 0.5 : new_alpha - 0.5;
                     int32_t quant_alpha = static_cast<int32_t>(new_alpha);
                     int32_t quant_beta  = 0;
                     if(orig_type == shape::int32_type)
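The two added lines above turn C++'s truncating integer cast into round-half-away-from-zero by shifting positive values up and negative values down by 0.5 before the cast. A small sketch of the identity being used (assertion values chosen for illustration):

#include <cassert>
#include <cstdint>

// Truncation after shifting by +/-0.5 rounds half away from zero.
int32_t round_away_from_zero(double v)
{
    return static_cast<int32_t>(v > 0.0 ? v + 0.5 : v - 0.5);
}

int main()
{
    assert(round_away_from_zero(2.4) == 2);
    assert(round_away_from_zero(2.5) == 3);
    assert(round_away_from_zero(-2.4) == -2);
    assert(round_away_from_zero(-2.5) == -3);
    return 0;
}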
@@ -308,7 +306,7 @@ void quantize_int8(program& prog,
                         auto l_beta = prog.add_literal(literal{oq_dot->get_shape(), vec_beta});
                         auto beta_c =
                             prog.insert_instruction(ins, op::mul{}, l_beta, inputs.back());
-                        prog.replace_instruction(ins, op::add{}, q_dot, beta_c);
+                        prog.replace_instruction(ins, op::add{}, oq_dot, beta_c);
                     }
                 }
             }
@@ -365,54 +363,37 @@ void quantize_int8(program& prog,
             auto dilation     = conv_op.dilation;
             auto padding_mode = conv_op.padding_mode;
             auto group        = conv_op.group;
-            auto adjust_factor = 1.0 / (ins_quant_params[0].first * ins_quant_params[1].first);
+            auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);

-            shape quant_shape = compute_shape(
-                op::quant_convolution{padding, stride, dilation, padding_mode, group},
-                converted_inputs);
-            std::vector<float> vec_factor(quant_shape.elements(), adjust_factor);
-            auto fl = prog.add_literal(literal{{orig_type, quant_shape.lens()}, vec_factor});
-            if(quant_shape.type() == orig_type)
-            {
-                if(adjust_factor == 1.0f)
-                {
-                    prog.replace_instruction(
-                        ins,
-                        op::quant_convolution{padding, stride, dilation, padding_mode, group},
-                        converted_inputs);
-                }
-                else
-                {
-                    auto quant_conv = prog.insert_instruction(
-                        ins,
-                        op::quant_convolution{padding, stride, dilation, padding_mode, group},
-                        converted_inputs);
-                    prog.replace_instruction(ins, op::mul{}, quant_conv, fl);
-                }
-            }
-            else
-            {
-                auto quant_conv = prog.insert_instruction(
-                    ins,
-                    op::quant_convolution{padding, stride, dilation, padding_mode, group},
-                    converted_inputs);
-                if(adjust_factor == 1.0f)
-                {
-                    prog.replace_instruction(ins, op::convert{orig_type}, quant_conv);
-                }
-                else
-                {
-                    auto oq_conv =
-                        prog.insert_instruction(ins, op::convert{orig_type}, quant_conv);
-                    prog.replace_instruction(ins, op::mul{}, oq_conv, fl);
-                }
-            }
+            auto quant_conv = prog.insert_instruction(
+                ins,
+                op::quant_convolution{padding, stride, dilation, padding_mode, group},
+                converted_inputs);
+            auto fp_conv = prog.insert_instruction(
+                ins, op::convert{shape::float_type, adjust_factor, 0.0f}, quant_conv);
+            prog.replace_instruction(ins, op::convert{orig_type, 1.0f, 0.0f}, fp_conv);
         }
         else
         {
-            MIGRAPHX_THROW("INT8_QUANTIZE: does not support operator" + ins->name());
+            MIGRAPHX_THROW("QUANTIZE_INT8: does not support operator " + ins->name());
         }
     }
+
+    if(quant_param_index != quant_params.size())
+    {
+        MIGRAPHX_THROW("QUANTIZE_INT8: number of scales does not match");
+    }
 }
-
-void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
-{
-    quantize_int8(prog, ins_names, int8_quant_params);
-}
-
-void quantize_int8(program& prog)
-{
-    std::vector<std::string> ins_names = {"dot", "convolution"};
-    quantize_int8(prog, ins_names);
-}

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
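The rewritten convolution path above replaces the old literal-multiply dequantization with two op::convert steps: the int32 output of quant_convolution is converted to float with scale adjust_factor, then converted back to the original type. The factor 1 / (scale_a * scale_b) works because each quantized input carries its own scale, so the integer accumulation approximates the fp32 result multiplied by both scales. A scalar sketch of that identity with made-up values (not MIGraphX code):

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// An int32 accumulation over inputs scaled by s_x and s_w approximates the
// fp32 result times s_x * s_w; multiplying by adjust_factor = 1 / (s_x * s_w)
// recovers the original scale.
int main()
{
    std::vector<float> x = {0.2f, -0.7f, 0.5f};
    std::vector<float> w = {1.1f, 0.3f, -0.9f};
    float s_x = 100.0f;
    float s_w = 50.0f;

    float ref   = 0.0f;
    int32_t acc = 0;
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        ref += x[i] * w[i];
        acc += static_cast<int32_t>(std::round(x[i] * s_x)) *
               static_cast<int32_t>(std::round(w[i] * s_w));
    }
    float adjust_factor = 1.0f / (s_x * s_w);
    // Both values come out close to -0.44 for these inputs.
    std::printf("fp32 = %f, dequantized int8 = %f\n", ref, acc * adjust_factor);
    return 0;
}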
@@ -71,7 +71,7 @@ argument miopen_quant_convolution::compute(context& ctx,
     }

     // Add a conversion from float to int32_t
-    device::convert(ctx.get_stream().get(), args[4], args[3]);
+    device::convert(ctx.get_stream().get(), args[4], args[3], 1.0f, 0.0f, shape::int32_type);

     return args[4];
 }
...