Commit efd73d5c authored by Shucai Xiao

clang format

parent 684d5a4e
@@ -27,8 +27,8 @@ std::shared_ptr<std::vector<std::pair<float, float>>>
 capture_arguments_impl(program& prog, const target& t, const std::vector<std::string>& ins_names);
 template <class T>
-std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(
-    program& prog, T&& t, const std::vector<std::string>& ins_names)
+std::shared_ptr<std::vector<std::pair<float, float>>>
+capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_names)
 {
     static_assert(std::is_same<std::remove_cv_t<std::remove_reference_t<T>>, target>{} &&
                       std::is_lvalue_reference<T>{},
@@ -41,8 +41,8 @@ void quantize_int8(program& prog,
                    std::vector<program::parameter_map>& calibration,
                    const std::vector<std::string>& ins_names = {"dot", "convolution"});
 void quantize_int8_impl(program& prog,
-const std::vector<std::pair<float, float>>& quant_params,
-const std::vector<std::string>& ins_names);
+                        const std::vector<std::pair<float, float>>& quant_params,
+                        const std::vector<std::string>& ins_names);
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
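
For orientation, here is a minimal usage sketch of the entry points declared in this header. It assumes the (prog, target, calibration, ins_names) signature that the Python binding in the next hunk exposes; the header paths, function name, and the gpu::target choice are illustrative assumptions, not part of this commit.

```cpp
// Illustrative sketch only (not part of this commit): quantize the dot and
// convolution instructions of a program using previously collected
// calibration inputs. Include paths and gpu::target are assumptions.
#include <string>
#include <vector>
#include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/gpu/target.hpp>

void quantize_program_int8(migraphx::program& prog,
                           std::vector<migraphx::program::parameter_map>& calibration)
{
    migraphx::gpu::target t{};
    // Only "dot" and "convolution" are accepted; other names make the
    // implementation throw, as the checks later in this diff show.
    migraphx::quantize_int8(prog, t, calibration, {"dot", "convolution"});
}
```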
@@ -183,9 +183,16 @@ PYBIND11_MODULE(migraphx, m)
     });
     m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
-    m.def("quantize_fp16", &migraphx::quantize_fp16, py::arg("prog"), py::arg("ins_names") = std::vector<std::string>{"all"});
-    m.def("quantize_int8", &migraphx::quantize_int8, py::arg("prog"), py::arg("t"), py::arg("calibration") = std::vector<migraphx::program::parameter_map>{},
-        py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
+    m.def("quantize_fp16",
+          &migraphx::quantize_fp16,
+          py::arg("prog"),
+          py::arg("ins_names") = std::vector<std::string>{"all"});
+    m.def("quantize_int8",
+          &migraphx::quantize_int8,
+          py::arg("prog"),
+          py::arg("t"),
+          py::arg("calibration") = std::vector<migraphx::program::parameter_map>{},
+          py::arg("ins_names") = std::vector<std::string>{"dot", "convolution"});
 #ifdef HAVE_GPU
     m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
@@ -308,8 +308,8 @@ static void ins_quantize_int8(program& prog,
 // a shift, then the convert can be done as v_int8 = fp * scale + shift.
 // To simplify the changes, we consider shift as 0.0f for now.
 void quantize_int8_impl(program& prog,
-const std::vector<std::pair<float, float>>& quant_params,
-const std::vector<std::string>& ins_names)
+                        const std::vector<std::pair<float, float>>& quant_params,
+                        const std::vector<std::string>& ins_names)
 {
     if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
     {
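
The comment in this hunk states the conversion rule the pass applies, v_int8 = fp * scale + shift, with shift held at 0.0f for now. As a standalone illustration of that rule (not MIGraphX code), per-tensor quantization with saturation to the int8 range looks roughly like this:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize float values to int8 with a per-tensor scale and shift, following
// the rule quoted above: v_int8 = fp * scale + shift (shift defaults to 0.0f).
std::vector<int8_t>
quantize_to_int8(const std::vector<float>& fp, float scale, float shift = 0.0f)
{
    std::vector<int8_t> result(fp.size());
    std::transform(fp.begin(), fp.end(), result.begin(), [&](float v) {
        float q = v * scale + shift;
        // Saturate to the int8 range before rounding.
        q = std::max(-128.0f, std::min(127.0f, q));
        return static_cast<int8_t>(std::round(q));
    });
    return result;
}
```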
@@ -325,7 +325,8 @@ void quantize_int8_impl(program& prog,
     // For now, we only support the int8 quantization of gemm and convolution
     std::set<std::string> op_names = {"convolution", "dot"};
     std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
-    if (!std::includes(op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
+    if(!std::includes(
+           op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
     {
         MIGRAPHX_THROW("QUANTIZE_INT8: only support DOT and CONVOLUTION operation");
     }
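
The guard above verifies that every requested instruction name belongs to the supported set: std::includes performs a sorted subset test, and the std::set containers supply the sorted ordering it requires. A small standalone illustration of the same check:

```cpp
#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

int main()
{
    // Supported operators and a user-requested list, mirroring the check above.
    std::set<std::string> op_names = {"convolution", "dot"};
    std::vector<std::string> requested = {"dot", "pooling"};
    std::set<std::string> input_ins_names(requested.begin(), requested.end());

    // std::includes needs both ranges sorted; std::set iterates in sorted order.
    bool supported = std::includes(
        op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end());
    std::cout << (supported ? "all supported" : "unsupported operator requested") << "\n";
    // Prints "unsupported operator requested" because "pooling" is not in op_names.
}
```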
@@ -458,7 +459,8 @@ std::size_t capture_arguments(program& prog,
     // the int8 quantization only support dot and convolution
     std::vector<std::string> op_names = {"dot", "convolution"};
    std::set<std::string> input_ins_names(ins_names.begin(), ins_names.end());
-    if (!std::includes(op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
+    if(!std::includes(
+           op_names.begin(), op_names.end(), input_ins_names.begin(), input_ins_names.end()))
     {
         MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
     }