Commit 12ff93ac authored by Shucai Xiao's avatar Shucai Xiao
Browse files

refine int8 quantization interface

parent b11bc73d
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/target.hpp>
#include <migraphx/program.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -21,11 +23,11 @@ std::size_t capture_arguments(program& prog, ...@@ -21,11 +23,11 @@ std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func); const std::function<void(std::size_t, std::vector<argument>)>& func);
std::shared_ptr<std::vector<std::pair<float, float>>> std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const std::vector<std::string>& ins_names); capture_arguments(program& prog, const target& t, const std::vector<std::string>& ins_names);
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog); std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog, const target& t);
void quantize_int8(program& prog); void quantize_int8(program& prog, const target& t, std::vector<program::parameter_map> &calibration_args);
void quantize_int8(program& prog, const std::vector<std::string>& ins_names); void quantize_int8(program& prog, const target& t, std::vector<program::parameter_map> &calibration_args, const std::vector<std::string>& ins_names);
void quantize_int8(program& prog, void quantize_int8(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
const std::vector<std::pair<float, float>>& quant_params); const std::vector<std::pair<float, float>>& quant_params);
......
...@@ -125,6 +125,24 @@ struct target ...@@ -125,6 +125,24 @@ struct target
return (*this).private_detail_te_get_handle().get_context(); return (*this).private_detail_te_get_handle().get_context();
} }
// Forwards copy_to to the handle held by this type-erased target.
// The handle member is asserted to be set before it is used.
argument copy_to(const argument& arg) const
{
    assert(this->private_detail_te_handle_mem_var);
    const auto& held = this->private_detail_te_get_handle();
    return held.copy_to(arg);
}
// Forwards copy_from to the handle held by this type-erased target.
// The handle member is asserted to be set before it is used.
argument copy_from(const argument& arg) const
{
    assert(this->private_detail_te_handle_mem_var);
    const auto& held = this->private_detail_te_get_handle();
    return held.copy_from(arg);
}
// Forwards allocate to the handle held by this type-erased target.
// The handle member is asserted to be set before it is used.
argument allocate(const shape& s) const
{
    assert(this->private_detail_te_handle_mem_var);
    const auto& held = this->private_detail_te_get_handle();
    return held.allocate(s);
}
friend bool is_shared(const target& private_detail_x, const target& private_detail_y) friend bool is_shared(const target& private_detail_x, const target& private_detail_y)
{ {
return private_detail_x.private_detail_te_handle_mem_var == return private_detail_x.private_detail_te_handle_mem_var ==
...@@ -141,6 +159,9 @@ struct target ...@@ -141,6 +159,9 @@ struct target
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual std::vector<pass> get_passes(context& ctx) const = 0; virtual std::vector<pass> get_passes(context& ctx) const = 0;
virtual context get_context() const = 0; virtual context get_context() const = 0;
// Pure virtual: a concrete implementation must supply copy_to.
virtual argument copy_to(const argument& arg) const = 0;
// Pure virtual: a concrete implementation must supply copy_from.
virtual argument copy_from(const argument& arg) const = 0;
// Pure virtual: a concrete implementation must supply allocate for a shape.
virtual argument allocate(const shape& s) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -181,6 +202,12 @@ struct target ...@@ -181,6 +202,12 @@ struct target
context get_context() const override { return private_detail_te_value.get_context(); } context get_context() const override { return private_detail_te_value.get_context(); }
// Delegates to the copy_to member of the stored value (private_detail_te_value).
argument copy_to(const argument& arg) const override
{
    return private_detail_te_value.copy_to(arg);
}
// Delegates to the copy_from member of the stored value (private_detail_te_value).
argument copy_from(const argument& arg) const override
{
    return private_detail_te_value.copy_from(arg);
}
// Delegates to the allocate member of the stored value (private_detail_te_value).
argument allocate(const shape& s) const override
{
    return private_detail_te_value.allocate(s);
}
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -193,10 +193,13 @@ PYBIND11_MODULE(migraphx, m) ...@@ -193,10 +193,13 @@ PYBIND11_MODULE(migraphx, m)
std::vector<std::pair<float, float>>& quant_params) { std::vector<std::pair<float, float>>& quant_params) {
migraphx::quantize_int8(p, ins_names, quant_params); migraphx::quantize_int8(p, ins_names, quant_params);
}); });
m.def("quantize_int8", [](migraphx::program& p, std::vector<std::string>& ins_names) { m.def("quantize_int8", [](migraphx::program& p, const migraphx::target& t,
migraphx::quantize_int8(p, ins_names); std::vector<migraphx::program::parameter_map>& cali_args,
std::vector<std::string>& ins_names) {
migraphx::quantize_int8(p, t, cali_args, ins_names);
}); });
m.def("quantize_int8", [](migraphx::program& p) { migraphx::quantize_int8(p); }); m.def("quantize_int8", [](migraphx::program& p, const migraphx::target& t,
std::vector<migraphx::program::parameter_map>& cali_args) { migraphx::quantize_int8(p, t, cali_args); });
#ifdef HAVE_GPU #ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <migraphx/op/multibroadcast.hpp> #include <migraphx/op/multibroadcast.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <utility> #include <utility>
#include <iomanip> #include <iomanip>
#include <fstream> #include <fstream>
...@@ -409,15 +410,44 @@ void quantize_int8(program& prog, ...@@ -409,15 +410,44 @@ void quantize_int8(program& prog,
} }
} }
// Calibrating overload taking a target and calibration data.
// Steps, as written below:
//   1. copy prog into cap_prog and call capture_arguments on the copy,
//   2. compile cap_prog for target t,
//   3. evaluate cap_prog once per entry of calibration_args,
//   4. call the three-argument quantize_int8 on prog with the collected
//      parameters.
// prog itself is only passed to the final quantize_int8 call.
void quantize_int8(program& prog, const std::vector<std::string>& ins_names) void quantize_int8(program& prog, const target& t,
std::vector<program::parameter_map> &calibration_args, const std::vector<std::string>& ins_names)
{ {
quantize_int8(prog, ins_names, *prog.int8_quant_params); // insert capture operator
auto cap_prog = prog;
auto int8_quant_params = capture_arguments(cap_prog, t, ins_names);
// use the calibration data to compute the quantization scale
cap_prog.compile(t);
// use all calibration data to run the program to calculate the
// quantization scale and shift
for (auto&& arg : calibration_args)
{
// Parameter map for this single calibration run.
program::parameter_map m;
for (auto&& x : cap_prog.get_parameter_shapes())
{
// Parameter supplied by the caller: pass it through t.copy_to.
if (arg.count(x.first) > 0)
{
// NOTE(review): the shape check is assert-only, so a mismatching
// calibration input is not rejected when NDEBUG is defined.
// NOTE(review): arg[x.first] is looked up again on each use; a single
// find() would avoid the repeated lookups.
assert(x.second == arg[x.first].get_shape());
m[x.first] = t.copy_to(arg[x.first]);
}
else
{
// Parameter not supplied: give it a buffer from t.allocate.
m[x.first] = t.allocate(x.second);
}
}
// Presumably the capture operators inserted above update
// *int8_quant_params during eval -- confirm in capture_arguments.
cap_prog.eval(m);
}
quantize_int8(prog, ins_names, *int8_quant_params);
} }
// Convenience overload: uses the instruction names "dot" and "convolution"
// and forwards to the overload that takes an explicit ins_names list.
void quantize_int8(program& prog) void quantize_int8(program& prog, const target& t,
std::vector<program::parameter_map> &calibration_args)
{ {
std::vector<std::string> ins_names = {"dot", "convolution"}; std::vector<std::string> ins_names = {"dot", "convolution"};
quantize_int8(prog, ins_names); quantize_int8(prog, t, calibration_args, ins_names);
} }
// For the input of each input argument, we need to insert a // For the input of each input argument, we need to insert a
...@@ -469,20 +499,20 @@ std::size_t capture_arguments(program& prog, ...@@ -469,20 +499,20 @@ std::size_t capture_arguments(program& prog,
} }
std::shared_ptr<std::vector<std::pair<float, float>>> std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const std::vector<std::string>& ins_names) capture_arguments(program& prog, const target& t, const std::vector<std::string>& ins_names)
{ {
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params = std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::make_shared<std::vector<std::pair<float, float>>>(); std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>(); std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
auto calc_quant_params = [int8_quant_params, max_abs_vals]( auto calc_quant_params = [int8_quant_params, max_abs_vals, &t](
std::size_t ins_index, std::vector<migraphx::argument> args) { std::size_t ins_index, std::vector<argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f}; std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not // scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0 // consider shift, so set shift to 0
std::vector<float> vec_val; std::vector<float> vec_val;
args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); }); t.copy_from(args.front()).visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
auto max_val = *std::max_element(vec_val.begin(), vec_val.end()); auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
auto min_val = *std::min_element(vec_val.begin(), vec_val.end()); auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val)); auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
...@@ -500,10 +530,10 @@ capture_arguments(program& prog, const std::vector<std::string>& ins_names) ...@@ -500,10 +530,10 @@ capture_arguments(program& prog, const std::vector<std::string>& ins_names)
return int8_quant_params; return int8_quant_params;
} }
// Convenience overload: uses the instruction names "dot" and "convolution"
// and forwards to the overload that takes an explicit ins_names list,
// returning whatever that overload returns.
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog) std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog, const target& t)
{ {
std::vector<std::string> ins_names = {"dot", "convolution"}; std::vector<std::string> ins_names = {"dot", "convolution"};
return capture_arguments(prog, ins_names); return capture_arguments(prog, t, ins_names);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -15,6 +15,10 @@ struct target ...@@ -15,6 +15,10 @@ struct target
std::string name() const; std::string name() const;
std::vector<pass> get_passes(migraphx::context& ctx) const; std::vector<pass> get_passes(migraphx::context& ctx) const;
migraphx::context get_context() const { return context{}; } migraphx::context get_context() const { return context{}; }
// cpu target: returns a copy of `arg` without any further processing.
argument copy_to(const argument& arg) const
{
    return arg;
}
// cpu target: returns a copy of `arg` without any further processing.
argument copy_from(const argument& arg) const
{
    return arg;
}
// Declaration only; the definition is out of line.
argument allocate(const shape& s) const;
}; };
} // namespace cpu } // namespace cpu
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/generate.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -22,6 +23,11 @@ std::vector<pass> target::get_passes(migraphx::context&) const ...@@ -22,6 +23,11 @@ std::vector<pass> target::get_passes(migraphx::context&) const
dead_code_elimination{}}; dead_code_elimination{}};
} }
// cpu target: builds the argument for shape `s` by calling fill_argument
// with 0 as its second argument (presumably the fill value -- confirm
// against migraphx/generate.hpp).
argument target::allocate(const shape& s) const
{
    auto filled = fill_argument(s, 0);
    return filled;
}
} // namespace cpu } // namespace cpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -13,6 +13,10 @@ struct target ...@@ -13,6 +13,10 @@ struct target
std::string name() const; std::string name() const;
std::vector<pass> get_passes(migraphx::context& gctx) const; std::vector<pass> get_passes(migraphx::context& gctx) const;
migraphx::context get_context() const; migraphx::context get_context() const;
// Declarations only; the definitions are out of line.
argument copy_to(const argument& arg) const;
argument copy_from(const argument& arg) const;
argument allocate(const shape& s) const;
}; };
} // namespace gpu } // namespace gpu
......
...@@ -85,6 +85,22 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -85,6 +85,22 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
std::string target::name() const { return "miopen"; } std::string target::name() const { return "miopen"; }
migraphx::context target::get_context() const { return context{}; } migraphx::context target::get_context() const { return context{}; }
// gpu target: passes `arg` to gpu::to_gpu and returns its result.
argument target::copy_to(const argument& arg) const
{
    auto converted = gpu::to_gpu(arg);
    return converted;
}
// gpu target: passes `arg` to gpu::from_gpu and returns its result.
argument target::copy_from(const argument& arg) const
{
    auto converted = gpu::from_gpu(arg);
    return converted;
}
// gpu target: obtains the argument for shape `s` from gpu::allocate_gpu,
// leaving that function's remaining parameters at their defaults.
argument target::allocate(const shape& s) const
{
    auto buffer = gpu::allocate_gpu(s);
    return buffer;
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment