Commit 8fbd2874 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 12ff93ac
......@@ -24,10 +24,16 @@ std::size_t capture_arguments(program& prog,
const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const target& t, const std::vector<std::string>& ins_names);
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog, const target& t);
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog,
const target& t);
void quantize_int8(program& prog, const target& t, std::vector<program::parameter_map> &calibration_args);
void quantize_int8(program& prog, const target& t, std::vector<program::parameter_map> &calibration_args, const std::vector<std::string>& ins_names);
void quantize_int8(program& prog,
const target& t,
std::vector<program::parameter_map>& calibration_args);
void quantize_int8(program& prog,
const target& t,
std::vector<program::parameter_map>& calibration_args,
const std::vector<std::string>& ins_names);
void quantize_int8(program& prog,
const std::vector<std::string>& ins_names,
const std::vector<std::pair<float, float>>& quant_params);
......
......@@ -202,11 +202,20 @@ struct target
context get_context() const override { return private_detail_te_value.get_context(); }
argument copy_to(const argument& arg) const override { return private_detail_te_value.copy_to(arg); }
argument copy_to(const argument& arg) const override
{
return private_detail_te_value.copy_to(arg);
}
argument copy_from(const argument& arg) const override { return private_detail_te_value.copy_from(arg); }
argument copy_from(const argument& arg) const override
{
return private_detail_te_value.copy_from(arg);
}
argument allocate(const shape& s) const override { return private_detail_te_value.allocate(s); }
argument allocate(const shape& s) const override
{
return private_detail_te_value.allocate(s);
}
PrivateDetailTypeErasedT private_detail_te_value;
};
......
......@@ -193,13 +193,19 @@ PYBIND11_MODULE(migraphx, m)
std::vector<std::pair<float, float>>& quant_params) {
migraphx::quantize_int8(p, ins_names, quant_params);
});
m.def("quantize_int8", [](migraphx::program& p, const migraphx::target& t,
std::vector<migraphx::program::parameter_map>& cali_args,
std::vector<std::string>& ins_names) {
migraphx::quantize_int8(p, t, cali_args, ins_names);
});
m.def("quantize_int8", [](migraphx::program& p, const migraphx::target& t,
std::vector<migraphx::program::parameter_map>& cali_args) { migraphx::quantize_int8(p, t, cali_args); });
m.def("quantize_int8",
[](migraphx::program& p,
const migraphx::target& t,
std::vector<migraphx::program::parameter_map>& cali_args,
std::vector<std::string>& ins_names) {
migraphx::quantize_int8(p, t, cali_args, ins_names);
});
m.def("quantize_int8",
[](migraphx::program& p,
const migraphx::target& t,
std::vector<migraphx::program::parameter_map>& cali_args) {
migraphx::quantize_int8(p, t, cali_args);
});
#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
......@@ -410,24 +410,26 @@ void quantize_int8(program& prog,
}
}
void quantize_int8(program& prog, const target& t,
std::vector<program::parameter_map> &calibration_args, const std::vector<std::string>& ins_names)
void quantize_int8(program& prog,
const target& t,
std::vector<program::parameter_map>& calibration_args,
const std::vector<std::string>& ins_names)
{
// insert capture operator
auto cap_prog = prog;
auto cap_prog = prog;
auto int8_quant_params = capture_arguments(cap_prog, t, ins_names);
// use the calibration data to compute the quantization scale
cap_prog.compile(t);
// use all calibration data to run the program to calculate the
// use all calibration data to run the program to calculate the
// quantization scale and shift
for (auto&& arg : calibration_args)
for(auto&& arg : calibration_args)
{
program::parameter_map m;
for (auto&& x : cap_prog.get_parameter_shapes())
for(auto&& x : cap_prog.get_parameter_shapes())
{
if (arg.count(x.first) > 0)
if(arg.count(x.first) > 0)
{
assert(x.second == arg[x.first].get_shape());
m[x.first] = t.copy_to(arg[x.first]);
......@@ -443,8 +445,9 @@ void quantize_int8(program& prog, const target& t,
quantize_int8(prog, ins_names, *int8_quant_params);
}
void quantize_int8(program& prog, const target& t,
std::vector<program::parameter_map> &calibration_args)
void quantize_int8(program& prog,
const target& t,
std::vector<program::parameter_map>& calibration_args)
{
std::vector<std::string> ins_names = {"dot", "convolution"};
quantize_int8(prog, t, calibration_args, ins_names);
......@@ -505,14 +508,16 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string>
std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](
std::size_t ins_index, std::vector<argument> args) {
auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](std::size_t ins_index,
std::vector<argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0
std::vector<float> vec_val;
t.copy_from(args.front()).visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
t.copy_from(args.front()).visit([&](auto output) {
vec_val.assign(output.begin(), output.end());
});
auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
......@@ -530,7 +535,8 @@ capture_arguments(program& prog, const target& t, const std::vector<std::string>
return int8_quant_params;
}
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog, const target& t)
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog,
const target& t)
{
std::vector<std::string> ins_names = {"dot", "convolution"};
return capture_arguments(prog, t, ins_names);
......
......@@ -23,10 +23,7 @@ std::vector<pass> target::get_passes(migraphx::context&) const
dead_code_elimination{}};
}
argument target::allocate(const shape& s) const
{
return fill_argument(s, 0);
}
argument target::allocate(const shape& s) const { return fill_argument(s, 0); }
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -86,20 +86,11 @@ std::string target::name() const { return "miopen"; }
migraphx::context target::get_context() const { return context{}; }
argument target::copy_to(const argument& arg) const
{
return gpu::to_gpu(arg);
}
argument target::copy_to(const argument& arg) const { return gpu::to_gpu(arg); }
argument target::copy_from(const argument& arg) const
{
return gpu::from_gpu(arg);
}
argument target::copy_from(const argument& arg) const { return gpu::from_gpu(arg); }
argument target::allocate(const shape& s) const
{
return gpu::allocate_gpu(s);
}
argument target::allocate(const shape& s) const { return gpu::allocate_gpu(s); }
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment