Commit b76e669c authored by Shucai Xiao
Browse files

merge changes from branch capture_more_changes

parents b4c1ec87 8efe1eeb
...@@ -126,9 +126,6 @@ struct program ...@@ -126,9 +126,6 @@ struct program
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::make_shared<std::vector<std::pair<float, float>>>();
private: private:
void assign(const program& p); void assign(const program& p);
......
...@@ -17,11 +17,12 @@ void quantize(program& prog); ...@@ -17,11 +17,12 @@ void quantize(program& prog);
// insert the capture operator for the inputs of each operator to be quantized // insert the capture operator for the inputs of each operator to be quantized
// to int8 // to int8
void capture_arguments(program& prog, std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func); const std::function<void(std::size_t, std::vector<argument>)>& func);
void capture_arguments(program& prog, const std::vector<std::string>& ins_names); std::shared_ptr<std::vector<std::pair<float, float>>>
void capture_arguments(program& prog); capture_arguments(program& prog, const std::vector<std::string>& ins_names);
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog);
void quantize_int8(program& prog); void quantize_int8(program& prog);
void quantize_int8(program& prog, const std::vector<std::string>& ins_names); void quantize_int8(program& prog, const std::vector<std::string>& ins_names);
......
...@@ -112,8 +112,7 @@ void program::assign(const program& p) ...@@ -112,8 +112,7 @@ void program::assign(const program& p)
{ {
impl->instructions.clear(); impl->instructions.clear();
} }
impl->ctx = p.impl->ctx; impl->ctx = p.impl->ctx;
int8_quant_params = p.int8_quant_params;
std::unordered_map<instruction_ref, instruction_ref> ins_map; std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
......
...@@ -187,9 +187,6 @@ PYBIND11_MODULE(migraphx, m) ...@@ -187,9 +187,6 @@ PYBIND11_MODULE(migraphx, m)
migraphx::quantize(p, ins_names); migraphx::quantize(p, ins_names);
}); });
m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); }); m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); });
m.def("quantize_int8", [](migraphx::program& p, std::vector<std::string>& ins_names) {
migraphx::quantize_int8(p, ins_names);
});
m.def("quantize_int8", m.def("quantize_int8",
[](migraphx::program& p, [](migraphx::program& p,
std::vector<std::string>& ins_names, std::vector<std::string>& ins_names,
...@@ -201,12 +198,6 @@ PYBIND11_MODULE(migraphx, m) ...@@ -201,12 +198,6 @@ PYBIND11_MODULE(migraphx, m)
}); });
m.def("quantize_int8", [](migraphx::program& p) { migraphx::quantize_int8(p); }); m.def("quantize_int8", [](migraphx::program& p) { migraphx::quantize_int8(p); });
m.def("capture_arguments", [](migraphx::program& p, const std::vector<std::string>& ins_names) {
migraphx::capture_arguments(p, ins_names);
});
m.def("capture_arguments", [](migraphx::program& p) { migraphx::capture_arguments(p); });
#ifdef HAVE_GPU #ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false); m.def("to_gpu", &migraphx::gpu::to_gpu, py::arg("arg"), py::arg("host") = false);
......
...@@ -421,9 +421,9 @@ void quantize_int8(program& prog) ...@@ -421,9 +421,9 @@ void quantize_int8(program& prog)
// For the input of each input argument, we need to insert a // For the input of each input argument, we need to insert a
// capture operator to compute the scale and shift // capture operator to compute the scale and shift
void capture_arguments(program& prog, std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func) const std::function<void(std::size_t, std::vector<argument>)>& func)
{ {
size_t num_quant_params = 0; size_t num_quant_params = 0;
...@@ -464,34 +464,45 @@ void capture_arguments(program& prog, ...@@ -464,34 +464,45 @@ void capture_arguments(program& prog,
instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args); instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
} }
// set one pair of parameter for each argument return num_quant_params;
prog.int8_quant_params->resize(num_quant_params, std::make_pair(-1.0f, -1.0f));
} }
void capture_arguments(program& prog, const std::vector<std::string>& ins_names) std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const std::vector<std::string>& ins_names)
{ {
auto calc_quant_params = [&](std::size_t ins_index, std::vector<migraphx::argument> args) { std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::pair<float, float> param_pair{1.0f, 0.0f}; std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
auto calc_quant_params = [int8_quant_params, max_abs_vals](
std::size_t ins_index, std::vector<migraphx::argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not // scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0 // consider shift, so set shift to 0
std::vector<float> vec_val; std::vector<float> vec_val;
args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); }); args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
auto max_val = *std::max_element(vec_val.begin(), vec_val.end()); auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
auto min_val = *std::min_element(vec_val.begin(), vec_val.end()); auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val)); auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
max_abs_vals->at(ins_index) = std::max(max_abs_vals->at(ins_index), max_abs);
param_pair.first = 127.0f / max_abs; param_pair.first = 127.0f / max_abs_vals->at(ins_index);
(*prog.int8_quant_params)[ins_index] = param_pair; int8_quant_params->at(ins_index) = param_pair;
}; };
capture_arguments(prog, ins_names, calc_quant_params); auto num_params = capture_arguments(prog, ins_names, calc_quant_params);
int8_quant_params->resize(num_params, std::make_pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(num_params, 0.0f);
return int8_quant_params;
} }
void capture_arguments(program& prog) std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog)
{ {
std::vector<std::string> ins_names = {"dot", "convolution"}; std::vector<std::string> ins_names = {"dot", "convolution"};
capture_arguments(prog, ins_names); return capture_arguments(prog, ins_names);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment