Commit 75f5ed4a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from branch int8_quantization.

parents 3119fa01 2363d06c
...@@ -39,8 +39,6 @@ else() ...@@ -39,8 +39,6 @@ else()
set(MIGRAPHX_ENABLE_GPU Off CACHE BOOL "") set(MIGRAPHX_ENABLE_GPU Off CACHE BOOL "")
endif() endif()
set(MIGRAPHX_ENABLE_TF Off CACHE BOOL "")
add_compile_options(-std=c++14) add_compile_options(-std=c++14)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
......
...@@ -42,10 +42,10 @@ struct convert : unary<convert> ...@@ -42,10 +42,10 @@ struct convert : unary<convert>
float res = scale * x + shift; float res = scale * x + shift;
if(target_type == shape::int8_type) if(target_type == shape::int8_type)
{ {
int factor = (res > 0) ? 1 : -1; int factor = (res >= 0.0f) ? 1 : -1;
res = res + factor * 0.5f; res = res + factor * 0.5f;
res = res > 127.0 ? 127.0 : res; res = res > 127.0f ? 127.0f : res;
res = res < -128.0 ? -128.0 : res; res = res < -128.0f ? -128.0f : res;
} }
return res; return res;
......
...@@ -12,12 +12,7 @@ if(MIGRAPHX_ENABLE_PYTHON) ...@@ -12,12 +12,7 @@ if(MIGRAPHX_ENABLE_PYTHON)
C_VISIBILITY_PRESET hidden C_VISIBILITY_PRESET hidden
CXX_VISIBILITY_PRESET hidden CXX_VISIBILITY_PRESET hidden
) )
if(MIGRAPHX_ENABLE_TF) target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu)
target_link_libraries(migraphx_py PRIVATE migraphx migraphx_tf migraphx_cpu)
target_compile_definitions(migraphx_py PRIVATE -DENABLE_TF)
else()
target_link_libraries(migraphx_py PRIVATE migraphx migraphx_onnx migraphx_cpu)
endif()
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
target_link_libraries(migraphx_py PRIVATE migraphx_gpu) target_link_libraries(migraphx_py PRIVATE migraphx_gpu)
target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU) target_compile_definitions(migraphx_py PRIVATE -DHAVE_GPU)
......
...@@ -6,11 +6,8 @@ ...@@ -6,11 +6,8 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#ifdef ENABLE_TF
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#else
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#endif
#ifdef HAVE_GPU #ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
...@@ -162,16 +159,13 @@ PYBIND11_MODULE(migraphx, m) ...@@ -162,16 +159,13 @@ PYBIND11_MODULE(migraphx, m)
.def("__ne__", std::not_equal_to<migraphx::program>{}) .def("__ne__", std::not_equal_to<migraphx::program>{})
.def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); }); .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });
#ifdef ENABLE_TF
m.def("parse_tf", m.def("parse_tf",
&migraphx::parse_tf, &migraphx::parse_tf,
"Parse tf protobuf (default format is nhwc)", "Parse tf protobuf (default format is nhwc)",
py::arg("filename"), py::arg("filename"),
py::arg("is_nhwc") = true); py::arg("is_nhwc") = true);
#else
m.def("parse_onnx", &migraphx::parse_onnx); m.def("parse_onnx", &migraphx::parse_onnx);
#endif
m.def("get_target", [](const std::string& name) -> migraphx::target { m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu") if(name == "cpu")
return migraphx::cpu::target{}; return migraphx::cpu::target{};
......
...@@ -327,9 +327,7 @@ void quantize_int8(program& prog, ...@@ -327,9 +327,7 @@ void quantize_int8(program& prog,
ins, ins,
op::quant_convolution{padding, stride, dilation, padding_mode, group}, op::quant_convolution{padding, stride, dilation, padding_mode, group},
converted_inputs); converted_inputs);
auto fp_conv = prog.insert_instruction( prog.replace_instruction(ins, op::convert{orig_type, adjust_factor, 0.0f}, quant_conv);
ins, op::convert{shape::float_type, adjust_factor, 0.0f}, quant_conv);
prog.replace_instruction(ins, op::convert{orig_type, 1.0f, 0.0f}, fp_conv);
} }
else else
{ {
......
...@@ -21,7 +21,7 @@ void convert(hipStream_t stream, ...@@ -21,7 +21,7 @@ void convert(hipStream_t stream,
{ {
gs_launch(stream, result.get_shape().elements())([=](auto i) { gs_launch(stream, result.get_shape().elements())([=](auto i) {
float res = input_ptr[i] * scale + shift; float res = input_ptr[i] * scale + shift;
int factor = (res > 0) ? 1 : -1; int factor = (res >= 0.0f) ? 1 : -1;
output_ptr[i] = static_cast<int8_t>( output_ptr[i] = static_cast<int8_t>(
std::min<float>(std::max<float>(-128.0f, res + factor * 0.5), 127.0f)); std::min<float>(std::max<float>(-128.0f, res + factor * 0.5), 127.0f));
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment