Unverified commit 158bf57c, authored by Shucai Xiao and committed by GitHub
Browse files

Add quantization c api (#547)



* add quantization_fp16 c api

* clang format

* add quantization c api

* clang format

* backup code for add_c_api of quantization

* add c/c++ api for the quantization

* clang format

* fix a cppcheck error

* fix cppcheck error

* add unit test for quantization apis

* clang format

* fix cppcheck error

* clang format

* refine unit tests to cover more code changes

* clang format

* refine unit tests for more code change coverage

* add an op_names class

* clang format

* refine a unit test for more code change coverage

* code backup

* clang format

* remove unnecessary code

* fix review comments

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 866cca5b
......@@ -6,6 +6,7 @@
#include <migraphx/target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/quantization.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
......@@ -108,6 +109,42 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>&
return result;
}
// Quantize `prog` to fp16 using the operator names in `names`.
// An empty list is treated as {"all"}. The fallback is applied to a local copy,
// so the caller's list (the object stored behind a migraphx_quantize_op_names
// handle) is never modified and names added later are not silently combined
// with "all".
void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
{
    std::vector<std::string> selected =
        names.empty() ? std::vector<std::string>{"all"} : names;
    migraphx::quantize_fp16(prog, selected);
}
// Settings collected through the C API before int8 quantization is run.
struct quantize_int8_options
{
    // Parameter maps appended by add_calibration_data(); starts empty.
    std::vector<program::parameter_map> calibration{};
    // Operator names appended by add_op_name(); starts empty.
    std::vector<std::string> op_names{};
};
// Append `name` to the operator-name list stored in `options`.
void add_op_name(quantize_int8_options& options, const char* name)
{
    auto& names = options.op_names;
    names.emplace_back(name);
}
// Append a copy of `data` to the calibration list stored in `options`.
void add_calibration_data(quantize_int8_options& options, program::parameter_map& data)
{
    auto& calibration_sets = options.calibration;
    calibration_sets.push_back(data);
}
// Run int8 quantization of `prog` for target `t` with the calibration data and
// operator names held in `options`.
// When no operator names were added, {"dot", "convolution"} is used. The
// fallback lives in a local copy so `options` is left exactly as the caller
// built it; a later add_op_name() then does not end up next to defaults the
// caller never asked for.
void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& options)
{
    std::vector<std::string> op_names =
        options.op_names.empty() ? std::vector<std::string>{"dot", "convolution"}
                                 : options.op_names;
    migraphx::quantize_int8(prog, t, options.calibration, op_names);
}
template <class T>
bool equal(const T& x, const T& y)
{
......@@ -238,6 +275,26 @@ struct migraphx_onnx_options
migraphx::onnx_options object;
};
// Handle type behind migraphx_quantize_op_names_t: wraps a vector of operator names.
extern "C" struct migraphx_quantize_op_names;
struct migraphx_quantize_op_names
{
    // Any constructor arguments are perfectly forwarded to the wrapped vector.
    template <class... Args>
    migraphx_quantize_op_names(Args&&... args) : object(std::forward<Args>(args)...)
    {
    }

    // The wrapped list of operator names.
    std::vector<std::string> object;
};
// Handle type behind migraphx_quantize_int8_options_t: wraps migraphx::quantize_int8_options.
extern "C" struct migraphx_quantize_int8_options;
struct migraphx_quantize_int8_options
{
    // Any constructor arguments are perfectly forwarded to the wrapped options object.
    template <class... Args>
    migraphx_quantize_int8_options(Args&&... args) : object(std::forward<Args>(args)...)
    {
    }

    // The wrapped int8 quantization options.
    migraphx::quantize_int8_options object;
};
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{
return migraphx::try_([&] { destroy((shape)); });
......@@ -674,3 +731,105 @@ extern "C" migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
migraphx::parse_onnx_buffer((data), (size), (options->object)));
});
}
// C entry point: passes the quantize_op_names handle to destroy().
// The call runs inside migraphx::try_, whose result is the returned status.
extern "C" migraphx_status
migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names)
{
    const auto status = migraphx::try_([&] { destroy(quantize_op_names); });
    return status;
}
// C entry point: allocates an empty std::vector<std::string> and stores it,
// cast to the handle type, in *quantize_op_names.
// The work runs inside migraphx::try_, whose result is the returned status.
extern "C" migraphx_status
migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names)
{
    const auto status = migraphx::try_([&] {
        *quantize_op_names =
            object_cast<migraphx_quantize_op_names_t>(allocate<std::vector<std::string>>());
    });
    return status;
}
// C entry point: appends `name` to the operator-name list behind `quantize_op_names`.
// Both the handle and `name` must be non-null; a null value raises
// migraphx_status_bad_param inside try_. The check on `name` matters because
// push_back builds a std::string from it, which is undefined for a null pointer.
// NOTE(review): the generator script declares this function, so the same check
// may need to be emitted by the generator — confirm.
extern "C" migraphx_status
migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names, const char* name)
{
    return migraphx::try_([&] {
        if(quantize_op_names == nullptr)
            MIGRAPHX_THROW(migraphx_status_bad_param,
                           "Bad parameter quantize_op_names: Null pointer");
        if(name == nullptr)
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
        (quantize_op_names->object).push_back((name));
    });
}
// C entry point: quantizes `prog` to fp16 using the operator names held by `name`.
// A null `prog` or `name` raises migraphx_status_bad_param inside try_.
extern "C" migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog,
                                                                migraphx_quantize_op_names_t name)
{
    const auto status = migraphx::try_([&] {
        if(prog == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
        }
        if(name == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
        }
        migraphx::quantize_fp16_with_op_names(prog->object, name->object);
    });
    return status;
}
// C entry point: quantizes `prog` to fp16 without an explicit operator-name list.
// A null `prog` raises migraphx_status_bad_param inside try_.
extern "C" migraphx_status migraphx_quantize_fp16(migraphx_program_t prog)
{
    const auto status = migraphx::try_([&] {
        if(prog == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
        }
        migraphx::quantize_fp16(prog->object);
    });
    return status;
}
// C entry point: passes the quantize_int8_options handle to destroy().
// The call runs inside migraphx::try_, whose result is the returned status.
extern "C" migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options)
{
    const auto status = migraphx::try_([&] { destroy(quantize_int8_options); });
    return status;
}
// C entry point: allocates a migraphx::quantize_int8_options and stores it,
// cast to the handle type, in *quantize_int8_options.
// The work runs inside migraphx::try_, whose result is the returned status.
extern "C" migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options)
{
    const auto status = migraphx::try_([&] {
        *quantize_int8_options = object_cast<migraphx_quantize_int8_options_t>(
            allocate<migraphx::quantize_int8_options>());
    });
    return status;
}
// C entry point: appends `name` to the operator names stored in `quantize_int8_options`.
// Both the handle and `name` must be non-null; a null value raises
// migraphx_status_bad_param inside try_. The check on `name` matters because
// migraphx::add_op_name builds a std::string from it, which is undefined for a
// null pointer.
// NOTE(review): the generator script declares this function, so the same check
// may need to be emitted by the generator — confirm.
extern "C" migraphx_status
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options,
                                           const char* name)
{
    return migraphx::try_([&] {
        if(quantize_int8_options == nullptr)
            MIGRAPHX_THROW(migraphx_status_bad_param,
                           "Bad parameter quantize_int8_options: Null pointer");
        if(name == nullptr)
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter name: Null pointer");
        migraphx::add_op_name((quantize_int8_options->object), (name));
    });
}
// C entry point: appends the parameter map behind `data` to the calibration
// list stored in `quantize_int8_options`.
// A null handle for either argument raises migraphx_status_bad_param inside try_.
extern "C" migraphx_status migraphx_quantize_int8_options_add_calibration_data(
    migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data)
{
    const auto status = migraphx::try_([&] {
        if(quantize_int8_options == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param,
                           "Bad parameter quantize_int8_options: Null pointer");
        }
        if(data == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter data: Null pointer");
        }
        migraphx::add_calibration_data(quantize_int8_options->object, data->object);
    });
    return status;
}
// C entry point: runs int8 quantization of `prog` for `target` with the settings
// held by `options` (forwarded to migraphx::quantize_int8_wrap).
// A null `prog`, `target` or `options` raises migraphx_status_bad_param inside try_.
extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
                                                  migraphx_target_t target,
                                                  migraphx_quantize_int8_options_t options)
{
    const auto status = migraphx::try_([&] {
        if(prog == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter prog: Null pointer");
        }
        if(target == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter target: Null pointer");
        }
        if(options == nullptr)
        {
            MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter options: Null pointer");
        }
        migraphx::quantize_int8_wrap(prog->object, target->object, options->object);
    });
    return status;
}
......@@ -70,6 +70,12 @@ typedef const struct migraphx_program* const_migraphx_program_t;
typedef struct migraphx_onnx_options* migraphx_onnx_options_t;
typedef const struct migraphx_onnx_options* const_migraphx_onnx_options_t;
/* Handle to a list of operator names, used by migraphx_quantize_fp16_with_op_names(). */
typedef struct migraphx_quantize_op_names* migraphx_quantize_op_names_t;
typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_names_t;
/* Handle to the int8 quantization options (operator names and calibration data),
 * used by migraphx_quantize_int8(). */
typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t;
typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t;
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_create(migraphx_shape_t* shape,
......@@ -197,6 +203,35 @@ migraphx_status migraphx_parse_onnx_buffer(migraphx_program_t* out,
size_t size,
migraphx_onnx_options_t options);
/* Destroys a handle obtained from migraphx_quantize_op_names_create(). */
migraphx_status migraphx_quantize_op_names_destroy(migraphx_quantize_op_names_t quantize_op_names);
/* Creates an empty operator-name list and writes its handle to *quantize_op_names. */
migraphx_status migraphx_quantize_op_names_create(migraphx_quantize_op_names_t* quantize_op_names);
/* Appends `name` to the list. Reports a bad-parameter status when the handle is null.
 * `name` must point to a valid NUL-terminated string. */
migraphx_status migraphx_quantize_op_names_add(migraphx_quantize_op_names_t quantize_op_names,
const char* name);
/* Quantizes `prog` to fp16 using the operator names in `name`; an empty list is
 * treated as the single name "all". Reports a bad-parameter status when either
 * handle is null. */
migraphx_status migraphx_quantize_fp16_with_op_names(migraphx_program_t prog,
migraphx_quantize_op_names_t name);
/* Quantizes `prog` to fp16 without an explicit operator-name list. Reports a
 * bad-parameter status when `prog` is null. */
migraphx_status migraphx_quantize_fp16(migraphx_program_t prog);
/* Destroys a handle obtained from migraphx_quantize_int8_options_create(). */
migraphx_status
migraphx_quantize_int8_options_destroy(migraphx_quantize_int8_options_t quantize_int8_options);
/* Creates an options object with no operator names and no calibration data and
 * writes its handle to *quantize_int8_options. */
migraphx_status
migraphx_quantize_int8_options_create(migraphx_quantize_int8_options_t* quantize_int8_options);
/* Appends `name` to the operator names in the options. Reports a bad-parameter
 * status when the handle is null. `name` must point to a valid NUL-terminated string. */
migraphx_status
migraphx_quantize_int8_options_add_op_name(migraphx_quantize_int8_options_t quantize_int8_options,
const char* name);
/* Appends the parameters in `data` to the calibration data in the options.
 * Reports a bad-parameter status when either handle is null. */
migraphx_status migraphx_quantize_int8_options_add_calibration_data(
migraphx_quantize_int8_options_t quantize_int8_options, migraphx_program_parameters_t data);
/* Quantizes `prog` to int8 for `target` using `options`; when the options hold no
 * operator names, "dot" and "convolution" are used. Reports a bad-parameter
 * status when any handle is null. */
migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target,
migraphx_quantize_int8_options_t options);
#ifdef __cplusplus
}
#endif
......
......@@ -375,6 +375,8 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters)
program_parameters(migraphx_program_parameters* p, borrow) { this->set_handle(p, borrow{}); }
program_parameters(migraphx_program_parameters* p) { this->set_handle(p, borrow{}); }
program_parameters() { this->make_handle(&migraphx_program_parameters_create); }
void add(const char* pname, const argument& pargument) const
......@@ -569,6 +571,62 @@ inline program parse_onnx_buffer(const std::string& buffer)
own{});
}
// C++ wrapper around a migraphx_quantize_op_names handle: a list of operator names.
struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names)
{
    // Creates a new handle through migraphx_quantize_op_names_create.
    quantize_op_names() { this->make_handle(&migraphx_quantize_op_names_create); }

    // Wraps an existing handle using the `own` tag.
    quantize_op_names(migraphx_quantize_op_names* handle, own)
    {
        this->set_handle(handle, own{});
    }

    // Appends `name` to the list via migraphx_quantize_op_names_add.
    void add(const std::string& name)
    {
        const char* raw_name = name.c_str();
        call(&migraphx_quantize_op_names_add, this->get_handle_ptr(), raw_name);
    }
};
// fp16 quantization apis
// Quantizes `prog` to fp16 using the operator names held by `names`
// (calls migraphx_quantize_fp16_with_op_names).
inline void quantize_fp16(const program& prog, const quantize_op_names& names)
{
    auto prog_handle  = prog.get_handle_ptr();
    auto names_handle = names.get_handle_ptr();
    call(&migraphx_quantize_fp16_with_op_names, prog_handle, names_handle);
}
// Quantizes `prog` to fp16 without an operator-name list (calls migraphx_quantize_fp16).
inline void quantize_fp16(const program& prog)
{
    auto prog_handle = prog.get_handle_ptr();
    call(&migraphx_quantize_fp16, prog_handle);
}
// C++ wrapper around a migraphx_quantize_int8_options handle.
struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options)
{
    // Creates a new handle through migraphx_quantize_int8_options_create.
    quantize_int8_options() { this->make_handle(&migraphx_quantize_int8_options_create); }

    // Wraps an existing handle using the `own` tag.
    quantize_int8_options(migraphx_quantize_int8_options* handle, own)
    {
        this->set_handle(handle, own{});
    }

    // Wraps an existing handle using the `borrow` tag.
    quantize_int8_options(migraphx_quantize_int8_options* handle, borrow)
    {
        this->set_handle(handle, borrow{});
    }

    // Appends `name` to the operator names via migraphx_quantize_int8_options_add_op_name.
    void add_op_name(const std::string& name)
    {
        const char* raw_name = name.c_str();
        call(&migraphx_quantize_int8_options_add_op_name, this->get_handle_ptr(), raw_name);
    }

    // Appends the parameters in `pp` as calibration data via
    // migraphx_quantize_int8_options_add_calibration_data.
    void add_calibration_data(const program_parameters& pp)
    {
        auto options_handle = this->get_handle_ptr();
        auto params_handle  = pp.get_handle_ptr();
        call(&migraphx_quantize_int8_options_add_calibration_data, options_handle, params_handle);
    }
};
// Quantizes `prog` to int8 for `ptarget` with the settings in `options`
// (calls migraphx_quantize_int8).
inline void
quantize_int8(const program& prog, const target& ptarget, const quantize_int8_options& options)
{
    auto prog_handle    = prog.get_handle_ptr();
    auto target_handle  = ptarget.get_handle_ptr();
    auto options_handle = options.get_handle_ptr();
    call(&migraphx_quantize_int8, prog_handle, target_handle, options_handle);
}
} // namespace api
} // namespace migraphx
......
......@@ -203,3 +203,41 @@ api.add_function('migraphx_parse_onnx_buffer',
options='migraphx::onnx_options'),
fname='migraphx::parse_onnx_buffer',
returns='migraphx::program')
# Declares the `migraphx_quantize_op_names` handle, backed by std::vector<std::string>:
# a `create` constructor and an `add` method mapped to the vector's push_back.
@api.handle('migraphx_quantize_op_names', 'std::vector<std::string>')
def quantize_op_names(h):
    h.constructor('create')
    h.method('add', api.params(name='const char*'), fname='push_back')
# Registers the C function migraphx_quantize_fp16_with_op_names, implemented by
# migraphx::quantize_fp16_with_op_names(program&, std::vector<std::string>&).
api.add_function('migraphx_quantize_fp16_with_op_names',
api.params(prog='migraphx::program&',
name='std::vector<std::string>&'),
fname='migraphx::quantize_fp16_with_op_names')
# Registers the C function migraphx_quantize_fp16, implemented by
# migraphx::quantize_fp16 called with only the program argument.
api.add_function('migraphx_quantize_fp16',
api.params(prog='migraphx::program&'),
fname='migraphx::quantize_fp16')
# Declares the quantize_int8_options handle: a `create` constructor plus two
# methods that invoke the free functions migraphx::add_op_name and
# migraphx::add_calibration_data.
# NOTE(review): presumably `auto_handle` derives the handle and C++ type names
# from the function name — confirm against its definition.
@auto_handle
def quantize_int8_options(h):
    h.constructor('create')
    h.method(
        'add_op_name',
        api.params(name='const char*'),
        invoke='migraphx::add_op_name($@)',
    )
    h.method(
        'add_calibration_data',
        api.params(data='std::unordered_map<std::string, migraphx::argument>'),
        invoke='migraphx::add_calibration_data($@)',
    )
# Registers the C function migraphx_quantize_int8, implemented by
# migraphx::quantize_int8_wrap(program, target, options).
api.add_function('migraphx_quantize_int8',
api.params(prog='migraphx::program&',
target='migraphx::target',
options='migraphx::quantize_int8_options'),
fname='migraphx::quantize_int8_wrap')
......@@ -22,6 +22,43 @@ TEST_CASE(load_and_run)
CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
}
// Exercises the three fp16 entry points: no names, an empty name list, and a
// list containing "dot".
TEST_CASE(quantize_fp16)
{
    auto prog = migraphx::parse_onnx("gemm_ex_test.onnx");
    // NOTE(review): both names below are references to `prog`, not separate
    // programs, so every quantize call acts on the same program and each CHECK
    // compares that program with itself. Confirm whether independent programs
    // were intended.
    const auto& prog_empty_names = prog;
    const auto& prog_dot_names   = prog;
    migraphx::quantize_fp16(prog);

    migraphx::quantize_op_names op_names;
    migraphx::quantize_fp16(prog_empty_names, op_names);
    CHECK(bool{prog == prog_empty_names});

    op_names.add("dot");
    migraphx::quantize_fp16(prog_dot_names, op_names);
    CHECK(bool{prog == prog_dot_names});
}
// Exercises int8 quantization with default options, then again after adding
// calibration data and the "dot" operator name.
TEST_CASE(quantize_int8)
{
    auto prog = migraphx::parse_onnx("gemm_ex_test.onnx");
    // NOTE(review): `prog_alias` is a reference to `prog`, so the second
    // quantize call acts on the already-quantized program and the final CHECK
    // compares the program with itself. Confirm whether an independent program
    // was intended.
    const auto& prog_alias = prog;
    auto cpu_target        = migraphx::target("cpu");

    migraphx::quantize_int8_options options;
    migraphx::quantize_int8(prog, cpu_target, options);

    // Build one calibration set with a generated argument for every program parameter.
    migraphx::program_parameters calibration_inputs;
    auto input_shapes = prog.get_parameter_shapes();
    for(auto&& input_name : input_shapes.names())
    {
        calibration_inputs.add(input_name,
                               migraphx::argument::generate(input_shapes[input_name]));
    }

    options.add_calibration_data(calibration_inputs);
    options.add_op_name("dot");
    migraphx::quantize_int8(prog_alias, cpu_target, options);
    CHECK(bool{prog == prog_alias});
}
TEST_CASE(load_and_run_user_input_shape)
{
migraphx::onnx_options options;
......
......@@ -1183,8 +1183,8 @@ def gemm_test():
@onnx_test
def gemm_ex_test():
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 5, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 5, 7])
m1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 8, 6])
m2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 1, 8, 7])
m3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1, 1, 6, 7])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 1, 6, 7])
......
......@@ -851,8 +851,8 @@ TEST_CASE(gemm_test)
TEST_CASE(gemm_ex_test)
{
migraphx::program p;
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 5, 7}});
auto l0 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 6}});
auto l1 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 1, 8, 7}});
auto l2 = p.add_parameter("3", migraphx::shape{migraphx::shape::float_type, {1, 1, 6, 7}});
auto t0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto alpha = 0.5f;
......
......@@ -6,6 +6,7 @@
#include <migraphx/target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/quantization.hpp>
#ifdef HAVE_GPU
#include <migraphx/gpu/target.hpp>
......@@ -108,6 +109,42 @@ std::vector<const char*> get_names(const std::unordered_map<std::string, Value>&
return result;
}
// Quantize `prog` to fp16 using the operator names in `names`.
// An empty list is treated as {"all"}. The fallback is applied to a local copy,
// so the caller's list is never modified and names added later are not
// silently combined with "all".
void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
{
    std::vector<std::string> selected =
        names.empty() ? std::vector<std::string>{"all"} : names;
    migraphx::quantize_fp16(prog, selected);
}
// Settings collected through the C API before int8 quantization is run.
struct quantize_int8_options
{
    // Parameter maps appended by add_calibration_data(); starts empty.
    std::vector<program::parameter_map> calibration{};
    // Operator names appended by add_op_name(); starts empty.
    std::vector<std::string> op_names{};
};
// Append `name` to the operator-name list stored in `options`.
void add_op_name(quantize_int8_options& options, const char* name)
{
    auto& names = options.op_names;
    names.emplace_back(name);
}
// Append a copy of `data` to the calibration list stored in `options`.
void add_calibration_data(quantize_int8_options& options, program::parameter_map& data)
{
    auto& calibration_sets = options.calibration;
    calibration_sets.push_back(data);
}
// Run int8 quantization of `prog` for target `t` with the calibration data and
// operator names held in `options`.
// When no operator names were added, {"dot", "convolution"} is used. The
// fallback lives in a local copy so `options` is left exactly as the caller
// built it; a later add_op_name() then does not end up next to defaults the
// caller never asked for.
void quantize_int8_wrap(program& prog, const target& t, quantize_int8_options& options)
{
    std::vector<std::string> op_names =
        options.op_names.empty() ? std::vector<std::string>{"dot", "convolution"}
                                 : options.op_names;
    migraphx::quantize_int8(prog, t, options.calibration, op_names);
}
template <class T>
bool equal(const T& x, const T& y)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment