Commit 28c7a058 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

modify to remove an unnecessary global variable

parent 52ed1fc3
......@@ -126,6 +126,8 @@ struct program
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
std::vector<std::pair<float, float>> int8_quant_params;
private:
void assign(const program& p);
......
......@@ -123,25 +123,6 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
void quantize(program& prog) { quantize(prog, {"all"}); }
static std::vector<std::pair<float, float>> int8_quant_params;
// function to compute the scale for each convert operator to convert to int8
void calc_quant_params(std::size_t ins_index, std::vector<migraphx::argument> args)
{
std::pair<float, float> param_pair{1.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0
std::vector<float> vec_val;
args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
param_pair.first = 127.0f / max_abs;
int8_quant_params[ins_index] = param_pair;
};
// int8 quantization is different from fp16 since int8 can only handle value
// -128 ~ 127. To convert the float or double to int8, we need a scale and
// a shift, then the convert can be done as v_int8 = fp * scale + shift.
......@@ -343,13 +324,13 @@ void quantize_int8(program& prog,
void quantize_int8(program& prog, const std::vector<std::string>& ins_names)
{
quantize_int8(prog, ins_names, int8_quant_params);
quantize_int8(prog, ins_names, prog.int8_quant_params);
}
void quantize_int8(program& prog)
{
std::vector<std::string> ins_names = {"dot", "convolution"};
quantize_int8(prog, ins_names, int8_quant_params);
quantize_int8(prog, ins_names, prog.int8_quant_params);
}
// For the input of each input argument, we need to insert a
......@@ -358,6 +339,7 @@ void capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
std::function<void(std::size_t, std::vector<argument>)> func)
{
size_t num_quant_params = 0;
// the int8 quantization only support dot and convolution
std::vector<std::string> op_names = {"dot", "convolution", "quant_dot", "quant_convolution"};
......@@ -397,11 +379,28 @@ void capture_arguments(program& prog,
}
// set one pair of parameter for each argument
int8_quant_params.resize(num_quant_params, std::make_pair(-1.0f, -1.0f));
prog.int8_quant_params.resize(num_quant_params, std::make_pair(-1.0f, -1.0f));
}
void capture_arguments(program& prog, const std::vector<std::string>& ins_names)
{
auto calc_quant_params = [&](std::size_t ins_index, std::vector<migraphx::argument> args)
{
std::pair<float, float> param_pair{1.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0
std::vector<float> vec_val;
args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
param_pair.first = 127.0f / max_abs;
prog.int8_quant_params[ins_index] = param_pair;
};
capture_arguments(prog, ins_names, calc_quant_params);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment