"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "17129c883d944d8adf17042f01951e2def163277"
Commit 4a10535c authored by Shucai Xiao

code backup

parent 762cee10
@@ -42,9 +42,10 @@ struct convert : unary<convert>
     float res = scale * x + shift;
     if(target_type == shape::int8_type)
     {
-        res = res + 0.5f;
-        res = res > 127.0 ? 127.0 : res;
-        res = res < -128.0 ? -128.0 : res;
+        int factor = (res > 0) ? 1 : -1;
+        res = res + factor * 0.5f;
+        res = res > 127.0 ? 127.0 : res;
+        res = res < -128.0 ? -128.0 : res;
     }
     return res;
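Reviewer note: the later float-to-int8 cast truncates toward zero, so the old unconditional +0.5f rounded negative values toward zero rather than to the nearest integer. The sign-dependent offset fixes that. A minimal standalone sketch of the new behavior (to_int8 is a hypothetical name, not migraphx code):

#include <algorithm>
#include <cstdint>
#include <iostream>

// Round half away from zero, then saturate to the int8 range.
std::int8_t to_int8(float x, float scale, float shift)
{
    float res = scale * x + shift;
    float factor = (res > 0) ? 1.0f : -1.0f;
    res = res + factor * 0.5f;                      // sign-aware rounding offset
    res = std::min(std::max(res, -128.0f), 127.0f); // clamp to int8 range
    return static_cast<std::int8_t>(res);           // cast truncates toward zero
}

int main()
{
    std::cout << int(to_int8(-2.6f, 1.0f, 0.0f)) << "\n"; // -3; the old +0.5f gave -2
    std::cout << int(to_int8(2.6f, 1.0f, 0.0f)) << "\n";  // 3
}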
@@ -15,7 +15,9 @@ struct program;
 void quantize(program& prog, const std::vector<std::string>& ins_names);
 void quantize(program& prog);
-void quantize_int8(program& prog, const std::vector<std::string>& ins_names);
+void quantize_int8(program& prog,
+                   const std::vector<std::string>& ins_names,
+                   std::vector<std::pair<float, float>>& int8_quant_params);
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
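Reviewer note: callers now pass one (scale, shift) pair per quantized input instead of relying on values hard-coded in quantize.cpp. The patch does not show how these pairs are obtained; one plausible scheme, consistent with the v_int8 = fp * scale + shift formula and the shift = 0 simplification below, maps the largest calibration magnitude onto the int8 range (calibrate_scale is a hypothetical helper, not part of the patch):

#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

// Symmetric calibration: map the largest observed magnitude to 127.
std::pair<float, float> calibrate_scale(const std::vector<float>& samples)
{
    float max_abs = 0.0f;
    for(float v : samples)
        max_abs = std::max(max_abs, std::fabs(v));
    float scale = (max_abs > 0.0f) ? 127.0f / max_abs : 1.0f; // avoid dividing by zero
    return {scale, 0.0f}; // shift is fixed at 0 for now, as in quantize.cpp
}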
@@ -186,6 +186,12 @@ PYBIND11_MODULE(migraphx, m)
         migraphx::quantize(p, ins_names);
     });
     m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); });
+    m.def("quantize_int8",
+          [](migraphx::program& p,
+             std::vector<std::string>& ins_names,
+             std::vector<std::pair<float, float>>& quant_params) {
+              migraphx::quantize_int8(p, ins_names, quant_params);
+          });
 #ifdef HAVE_GPU
     m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
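Reviewer note: the binding relies on pybind11's automatic STL conversions, so a Python caller can pass quant_params as a list of 2-tuples such as [(127.0, 0.0), (127.0, 0.0)]. Those conversions are only enabled when <pybind11/stl.h> is included in the translation unit. A self-contained sketch of the mechanism (module name example is illustrative):

#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // enables list-of-tuples -> std::vector<std::pair<...>>
#include <utility>
#include <vector>

namespace py = pybind11;

PYBIND11_MODULE(example, m)
{
    // In Python: example.count_params([(127.0, 0.0), (127.0, 0.0)]) == 2
    m.def("count_params", [](const std::vector<std::pair<float, float>>& quant_params) {
        return quant_params.size();
    });
}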
@@ -119,8 +119,38 @@ void quantize(program& prog) { quantize(prog, {"all"}); }
 // -128 ~ 127. To convert the float or double to int8, we need a scale and
 // a shift, then the convert can be done as v_int8 = fp * scale + shift.
 // To simplify the changes, we consider shift as 0.0f for now.
-void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
+void quantize_int8(program& prog,
+                   const std::vector<std::string>& ins_names,
+                   std::vector<std::pair<float, float>>& int8_quant_params)
 {
+    // // For debugging
+    // auto print_gemm_res = [&](std::size_t ins_index, std::vector<migraphx::argument> args) {
+    //     // scale and shift are needed only for the int8 type, and we do
+    //     // not consider shift, so set shift to 0
+    //     std::vector<float> vec_val;
+    //     args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
+    //     std::cout << "quant_gemm = " << std::endl;
+    //     for(size_t i = 0; i < 20; i++)
+    //     {
+    //         std::cout << vec_val[i] << "\t";
+    //     }
+    //     std::cout << std::endl;
+    // };
+    // // For debugging
+    // auto print_conv_res = [&](std::size_t ins_index, std::vector<migraphx::argument> args) {
+    //     // scale and shift are needed only for the int8 type, and we do
+    //     // not consider shift, so set shift to 0
+    //     std::vector<float> vec_val;
+    //     args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
+    //     std::cout << "quant_conv = " << std::endl;
+    //     for(size_t i = 0; i < 20; i++)
+    //     {
+    //         std::cout << vec_val[i] << "\t";
+    //     }
+    //     std::cout << std::endl;
+    // };
     // For now, we only support the int8 quantization of gemm and convolution
     std::vector<std::string> op_names = {"dot", "convolution"};
     if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
@@ -130,9 +160,7 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
         MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
     }
-    // tmp value used just testing
-    std::vector<std::pair<float, float>> int8_param{{127.0f, 0.0f}, {127.0f, 0.0f}, {128.0f, 0.0f}};
+    std::size_t quant_param_index = 0;
     std::unordered_map<instruction_ref, instruction_ref> map_quant_ins;
     for(auto ins : iterator_for(prog))
     {
@@ -150,15 +178,16 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
         // process all inputs; if an input is fp32 or fp64, convert it
         // to an int8 type by adding a convert operator and replace
         // the operator with the corresponding int8 version
-        auto inputs = ins->inputs();
-        std::size_t param_index = 0;
+        auto inputs = ins->inputs();
+        std::vector<std::pair<float, float>> ins_quant_params;
         for(auto input : inputs)
         {
             // In general, the target_type is int8, but for the dot
             // operation, if it has 3 inputs, then the last one should
             // be converted to int32_type
             shape::type_t quant_type = shape::int8_type;
-            auto param = int8_param[param_index++];
+            auto param = int8_quant_params[quant_param_index++];
+            ins_quant_params.push_back(param);
             if(ins->name() == "dot" and inputs.size() == 3 and input == inputs.back())
             {
                 quant_type = shape::int32_type;
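Reviewer note: quant_param_index is declared before the instruction loop, so int8_quant_params is consumed linearly: one (scale, shift) pair for every input of every quantized dot or convolution, in program order. The caller must size and order the vector to match; no bounds check is visible in this hunk, so an undersized vector would be indexed past its end.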
@@ -210,9 +239,10 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
         // equal)", we need additional calculation for the adjustment
         if(ins->name() == "dot")
         {
-            auto dot_op = any_cast<op::dot>(ins->get_operator());
-            float new_alpha = dot_op.alpha / (int8_param[0].first * int8_param[1].first);
-            float new_beta = dot_op.beta;
+            auto dot_op = any_cast<op::dot>(ins->get_operator());
+            float new_alpha =
+                dot_op.alpha / (ins_quant_params[0].first * ins_quant_params[1].first);
+            float new_beta = dot_op.beta;
             // We need additional checking about the quant_alpha value. If
             // abs(quant_alpha) > 50 (some tmp value set here), we can convert
             // it to an integer as the new_alpha in the quant_dot
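Reviewer note: the alpha rescaling follows from the conversion formula quoted earlier. With shift = 0, each quantized operand carries a factor of its scale, so the integer dot product is inflated by the product of the two input scales, and dividing alpha by that product recovers the fp32 result. A small sketch of the arithmetic (the scale values are illustrative):

#include <cstdio>

int main()
{
    float a = 0.5f, b = 0.25f, alpha = 2.0f;
    float s_a = 127.0f, s_b = 127.0f;                // per-input scales, shift == 0
    float a_q = s_a * a;                             // quantized operands carry the scales
    float b_q = s_b * b;
    float fp32  = alpha * a * b;                     // reference result: 0.25
    float quant = (alpha / (s_a * s_b)) * a_q * b_q; // new_alpha recovers it: 0.25
    std::printf("%f %f\n", fp32, quant);
}

The convolution hunk below applies the same correction through adjust_factor.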
@@ -335,7 +365,7 @@ void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
         auto dilation     = conv_op.dilation;
         auto padding_mode = conv_op.padding_mode;
         auto group        = conv_op.group;
-        auto adjust_factor = 1.0 / (int8_param[0].first * int8_param[1].first);
+        auto adjust_factor = 1.0 / (ins_quant_params[0].first * ins_quant_params[1].first);
         shape quant_shape =
             compute_shape(op::quant_convolution{padding, stride, dilation, padding_mode, group},
@@ -20,8 +20,10 @@ void convert(hipStream_t stream,
     if(target_type == shape::int8_type)
     {
         gs_launch(stream, result.get_shape().elements())([=](auto i) {
-            output_ptr[i] = std::min<int8_t>(
-                std::max<float>(-128, input_ptr[i] * scale + shift + 0.5), 127);
+            float res = input_ptr[i] * scale + shift;
+            int factor = (res > 0) ? 1 : -1;
+            output_ptr[i] =
+                std::min<int8_t>(std::max<float>(-128, res + factor * 0.5), 127);
         });
     }
     else
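Reviewer note: this kernel change mirrors the reference convert operator in the first hunk, so the host and GPU int8 conversion paths now apply the same sign-aware rounding before saturating to [-128, 127], keeping CPU and GPU results consistent.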