Commit 492d329a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into bugs_for_bert

parents f8cb174d 95050fbd
......@@ -126,9 +126,6 @@ struct program
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::make_shared<std::vector<std::pair<float, float>>>();
private:
void assign(const program& p);
......
......@@ -17,11 +17,12 @@ void quantize(program& prog);
// insert the capture operator for the inputs of each operator to be quantized
// to int8
void capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func);
void capture_arguments(program& prog, const std::vector<std::string>& ins_names);
void capture_arguments(program& prog);
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const std::vector<std::string>& ins_names);
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -112,8 +112,7 @@ void program::assign(const program& p)
{
impl->instructions.clear();
}
impl->ctx = p.impl->ctx;
int8_quant_params = p.int8_quant_params;
impl->ctx = p.impl->ctx;
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(p))
......
......@@ -187,11 +187,6 @@ PYBIND11_MODULE(migraphx, m)
migraphx::quantize(p, ins_names);
});
m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); });
m.def("capture_arguments", [](migraphx::program& p, const std::vector<std::string>& ins_names) {
migraphx::capture_arguments(p, ins_names);
});
m.def("capture_arguments", [](migraphx::program& p) { migraphx::capture_arguments(p); });
#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
......@@ -119,9 +119,9 @@ void quantize(program& prog) { quantize(prog, {"all"}); }
// For the input of each input argument, we need to insert a
// capture operator to compute the scale and shift
void capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func)
std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func)
{
size_t num_quant_params = 0;
......@@ -162,34 +162,45 @@ void capture_arguments(program& prog,
instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
}
// set one pair of parameter for each argument
prog.int8_quant_params->resize(num_quant_params, std::make_pair(-1.0f, -1.0f));
return num_quant_params;
}
void capture_arguments(program& prog, const std::vector<std::string>& ins_names)
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const std::vector<std::string>& ins_names)
{
auto calc_quant_params = [&](std::size_t ins_index, std::vector<migraphx::argument> args) {
std::pair<float, float> param_pair{1.0f, 0.0f};
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
auto calc_quant_params = [int8_quant_params, max_abs_vals](
std::size_t ins_index, std::vector<migraphx::argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0
std::vector<float> vec_val;
args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
max_abs_vals->at(ins_index) = std::max(max_abs_vals->at(ins_index), max_abs);
param_pair.first = 127.0f / max_abs;
(*prog.int8_quant_params)[ins_index] = param_pair;
param_pair.first = 127.0f / max_abs_vals->at(ins_index);
int8_quant_params->at(ins_index) = param_pair;
};
capture_arguments(prog, ins_names, calc_quant_params);
auto num_params = capture_arguments(prog, ins_names, calc_quant_params);
int8_quant_params->resize(num_params, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(num_params, 0.0f);
return int8_quant_params;
}
void capture_arguments(program& prog)
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog)
{
std::vector<std::string> ins_names = {"dot", "convolution"};
capture_arguments(prog, ins_names);
return capture_arguments(prog, ins_names);
}
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment