Commit 684d5a4e authored by Shucai Xiao
Browse files

fix review comments

parent 0fbdaf58
......@@ -15,10 +15,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program;
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names);
void quantize_fp16(program& prog);
void quantize_fp16(program& prog, const std::vector<std::string>& ins_names = {"all"});
// insert the capture operator for the inputs of each operator to be quantized
// to int8
......@@ -31,7 +28,7 @@ capture_arguments_impl(program& prog, const target& t, const std::vector<std::st
template <class T>
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(
program& prog, T&& t, const std::vector<std::string>& ins_names = {"dot", "convolution"})
program& prog, T&& t, const std::vector<std::string>& ins_names)
{
static_assert(std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, target>{} &&
std::is_lvalue_reference<T>{},
......@@ -41,9 +38,9 @@ std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(
void quantize_int8(program& prog,
const target& t,
std::vector<program::parameter_map>& calibration_args,
std::vector<program::parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8(program& prog,
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names);
......
......@@ -183,29 +183,9 @@ PYBIND11_MODULE(migraphx, m)
});
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("quantize_fp16", [](migraphx::program& p, std::vector<std::string>& ins_names) {
migraphx::quantize_fp16(p, ins_names);
});
m.def("quantize_fp16", [](migraphx::program& p) { migraphx::quantize_fp16(p, {"all"}); });
m.def("quantize_int8",
[](migraphx::program& p,
std::vector<std::string>& ins_names,
std::vector<std::pair<float, float>>& quant_params) {
migraphx::quantize_int8(p, quant_params, ins_names);
});
m.def("quantize_int8",
[](migraphx::program& p,
const migraphx::target& t,
std::vector<std::string>& ins_names,
std::vector<migraphx::program::parameter_map>& cali_args) {
migraphx::quantize_int8(p, t, cali_args, ins_names);
});
m.def("quantize_int8",
[](migraphx::program& p,
const migraphx::target& t,
std::vector<migraphx::program::parameter_map>& cali_args) {
migraphx::quantize_int8(p, t, cali_args);
});
m.def("quantize_fp16", &migraphx::quantize_fp16, py::arg("prog"), py::arg("ins_names") = std::vector<std::string>{"all"});
m.def("quantize_int8", &migraphx::quantize_int8, py::arg("prog"), py::arg("t"), py::arg("calibration") = std::vector<migraphx::program::parameter_map>{},
py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
......@@ -17,12 +17,16 @@
#include <migraphx/ranges.hpp>
#include <migraphx/target.hpp>
#include <utility>
#include <set>
#include <iomanip>
#include <fstream>
#include <algorithm>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_INT8_QUANTIZATION_PARAMS)
instruction_ref insert_quant_ins(program& prog,
instruction_ref& ins,
shape::type_t type,
......@@ -161,8 +165,6 @@ void quantize_fp16(program& prog, const std::vector<std::string>& ins_names)
}
}
void quantize_fp16(program& prog) { quantize_fp16(prog, {"all"}); }
static void ins_quantize_int8(program& prog,
instruction_ref ins,
std::vector<instruction_ref>& converted_inputs,
......@@ -305,7 +307,7 @@ static void ins_quantize_int8(program& prog,
// -128 ~ 127. To convert a float or double to int8, we need a scale and
// a shift; the conversion can then be done as v_int8 = fp * scale + shift.
// To simplify the changes, we consider shift as 0.0f for now.
void quantize_int8(program& prog,
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
const std::vector<std::string>& ins_names)
{
......@@ -321,10 +323,9 @@ void quantize_int8(program& prog,
}
// For now, we only support the int8 quantization of gemm and convolution
std::vector<std::string> op_names = {"dot", "convolution"};
if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
return (std::find(op_names.begin(), op_names.end(), name) != op_names.end());
}))
std::set<std::string> op_names = {"convolution", "dot"};
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if (!std::includes(op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
}
......@@ -381,7 +382,7 @@ void quantize_int8(program& prog,
quant_input = input->inputs().front();
// the scale is not used in this case, so set the scale
// to 1.0f for this parameter
ins_quant_params.back() = std::make_pair<float, float>(1.0f, 0.0f);
ins_quant_params.back() = std::pair<float, float>(1.0f, 0.0f);
}
else
{
......@@ -413,7 +414,7 @@ void quantize_int8(program& prog,
void quantize_int8(program& prog,
const target& t,
std::vector<program::parameter_map>& calibration_args,
std::vector<program::parameter_map>& calibration,
const std::vector<std::string>& ins_names)
{
// insert capture operator
......@@ -425,7 +426,7 @@ void quantize_int8(program& prog,
// use all calibration data to run the program to calculate the
// quantization scale and shift
for(auto&& arg : calibration_args)
for(auto&& arg : calibration)
{
program::parameter_map m;
for(auto&& x : cap_prog.get_parameter_shapes())
......@@ -443,7 +444,7 @@ void quantize_int8(program& prog,
cap_prog.eval(m);
}
quantize_int8(prog, *int8_quant_params, ins_names);
quantize_int8_impl(prog, *int8_quant_params, ins_names);
}
// For the input of each input argument, we need to insert a
......@@ -456,9 +457,8 @@ std::size_t capture_arguments(program& prog,
size_t num_quant_params = 0;
// int8 quantization only supports dot and convolution
std::vector<std::string> op_names = {"dot", "convolution"};
if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
return std::find(op_names.begin(), op_names.end(), name) != op_names.end();
}))
std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
if (!std::includes(op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
{
MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
}
......
......@@ -2068,7 +2068,7 @@ TEST_CASE(op_capture)
migraphx::program capture_p = p;
migraphx::target t = migraphx::cpu::target{};
migraphx::capture_arguments(capture_p, t);
migraphx::capture_arguments(capture_p, t, {"dot"});
p.compile(migraphx::cpu::target{});
capture_p.compile(migraphx::cpu::target{});
......
......@@ -250,7 +250,7 @@ TEST_CASE(op_capture)
auto p = create_program_float();
auto op_capture_p = create_program_op();
migraphx::target t = migraphx::cpu::target{};
migraphx::capture_arguments(p, t);
migraphx::capture_arguments(p, t, {"dot", "convolution"});
EXPECT(p == op_capture_p);
}
}
......@@ -313,7 +313,7 @@ TEST_CASE(dot_float)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 0.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"});
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
migraphx::run_passes(p, {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
......@@ -375,7 +375,7 @@ TEST_CASE(dot_double_2args)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"});
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -440,7 +440,7 @@ TEST_CASE(dot_large_alpha_beta_float)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"});
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -504,7 +504,7 @@ TEST_CASE(dot_large_alpha_beta_int32)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"});
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -548,7 +548,7 @@ TEST_CASE(dot_int32_one_arg)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{1.0f, 1.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"});
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -622,7 +622,7 @@ TEST_CASE(dot_int32)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{
{0.1f, 1.0f}, {0.1f, 0.0f}, {0.1f, 100.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"});
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -671,7 +671,7 @@ TEST_CASE(dot_float_convert)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 1.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, quant_params, {"dot"});
migraphx::quantize_int8_impl(p, quant_params, {"dot"});
migraphx::run_passes(p, {migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog();
......@@ -726,7 +726,7 @@ TEST_CASE(conv_float)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, quant_params, {"convolution"});
migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -782,7 +782,7 @@ TEST_CASE(conv_int32)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, quant_params, {"convolution"});
migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......@@ -840,7 +840,7 @@ TEST_CASE(conv_half)
auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
migraphx::quantize_int8(p, quant_params, {"convolution"});
migraphx::quantize_int8_impl(p, quant_params, {"convolution"});
auto qp = create_int8_quantized_prog();
EXPECT(p == qp);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment