Unverified Commit 0e6bd17c authored by kahmed10's avatar kahmed10 Committed by GitHub
Browse files

Expose context in C++ API (#1118)

Add a callable C++ API to migraphx
parent d71a7b6a
...@@ -194,6 +194,8 @@ void print_program(const program& p) { std::cout << p << std::endl; } ...@@ -194,6 +194,8 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; } void print_module(const module& m) { std::cout << m << std::endl; }
migraphx::context get_context(const program& p) { return p.get_context(); }
} // namespace migraphx } // namespace migraphx
template <class T, class U, class Target = std::remove_pointer_t<T>> template <class T, class U, class Target = std::remove_pointer_t<T>>
...@@ -379,6 +381,16 @@ struct migraphx_quantize_int8_options ...@@ -379,6 +381,16 @@ struct migraphx_quantize_int8_options
migraphx::quantize_int8_options object; migraphx::quantize_int8_options object;
}; };
extern "C" struct migraphx_context;
struct migraphx_context
{
template <class... Ts>
migraphx_context(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
migraphx::context object;
};
extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape) extern "C" migraphx_status migraphx_shape_destroy(migraphx_shape_t shape)
{ {
auto api_error_result = migraphx::try_([&] { destroy((shape)); }); auto api_error_result = migraphx::try_([&] { destroy((shape)); });
...@@ -883,6 +895,17 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap ...@@ -883,6 +895,17 @@ migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migrap
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_program_get_context(migraphx_context_t* out,
const_migraphx_program_t program)
{
auto api_error_result = migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = allocate<migraphx_context_t>(migraphx::get_context((program->object)));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation) extern "C" migraphx_status migraphx_operation_destroy(migraphx_operation_t operation)
{ {
auto api_error_result = migraphx::try_([&] { destroy((operation)); }); auto api_error_result = migraphx::try_([&] { destroy((operation)); });
...@@ -1324,3 +1347,13 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog, ...@@ -1324,3 +1347,13 @@ extern "C" migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
}); });
return api_error_result; return api_error_result;
} }
extern "C" migraphx_status migraphx_context_finish(const_migraphx_context_t context)
{
auto api_error_result = migraphx::try_([&] {
if(context == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter context: Null pointer");
(context->object).finish();
});
return api_error_result;
}
...@@ -91,6 +91,9 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name ...@@ -91,6 +91,9 @@ typedef const struct migraphx_quantize_op_names* const_migraphx_quantize_op_name
typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t; typedef struct migraphx_quantize_int8_options* migraphx_quantize_int8_options_t;
typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t; typedef const struct migraphx_quantize_int8_options* const_migraphx_quantize_int8_options_t;
typedef struct migraphx_context* migraphx_context_t;
typedef const struct migraphx_context* const_migraphx_context_t;
migraphx_status migraphx_shape_destroy(migraphx_shape_t shape); migraphx_status migraphx_shape_destroy(migraphx_shape_t shape);
migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input); migraphx_status migraphx_shape_assign_to(migraphx_shape_t output, const_migraphx_shape_t input);
...@@ -229,6 +232,9 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out, ...@@ -229,6 +232,9 @@ migraphx_status migraphx_program_run(migraphx_arguments_t* out,
migraphx_status migraphx_status
migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x); migraphx_program_equal(bool* out, const_migraphx_program_t program, const_migraphx_program_t x);
migraphx_status migraphx_program_get_context(migraphx_context_t* out,
const_migraphx_program_t program);
migraphx_status migraphx_operation_destroy(migraphx_operation_t operation); migraphx_status migraphx_operation_destroy(migraphx_operation_t operation);
migraphx_status migraphx_operation_assign_to(migraphx_operation_t output, migraphx_status migraphx_operation_assign_to(migraphx_operation_t output,
...@@ -355,6 +361,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog, ...@@ -355,6 +361,8 @@ migraphx_status migraphx_quantize_int8(migraphx_program_t prog,
migraphx_target_t target, migraphx_target_t target,
migraphx_quantize_int8_options_t options); migraphx_quantize_int8_options_t options);
migraphx_status migraphx_context_finish(const_migraphx_context_t context);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
......
...@@ -531,6 +531,13 @@ struct module ...@@ -531,6 +531,13 @@ struct module
void print() const { call(&migraphx_module_print, mm); } void print() const { call(&migraphx_module_print, mm); }
}; };
struct context
{
migraphx_context_t ctx;
void finish() const { call(&migraphx_context_finish, ctx); }
};
struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options)
{ {
compile_options() { this->make_handle(&migraphx_compile_options_create); } compile_options() { this->make_handle(&migraphx_compile_options_create); }
...@@ -627,6 +634,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -627,6 +634,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return module{p_modu}; return module{p_modu};
} }
context get_context()
{
migraphx_context_t ctx;
call(&migraphx_program_get_context, &ctx, this->get_handle_ptr());
return context{ctx};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); } friend bool operator!=(const program& px, const program& py) { return !(px == py); }
}; };
......
...@@ -207,6 +207,10 @@ def program(h): ...@@ -207,6 +207,10 @@ def program(h):
invoke='migraphx::equal($@)', invoke='migraphx::equal($@)',
returns='bool', returns='bool',
const=True) const=True)
h.method('get_context',
invoke='migraphx::get_context($@)',
const=True,
returns='migraphx::context')
@auto_handle() @auto_handle()
...@@ -353,3 +357,8 @@ api.add_function('migraphx_quantize_int8', ...@@ -353,3 +357,8 @@ api.add_function('migraphx_quantize_int8',
target='migraphx::target', target='migraphx::target',
options='migraphx::quantize_int8_options'), options='migraphx::quantize_int8_options'),
fname='migraphx::quantize_int8_wrap') fname='migraphx::quantize_int8_wrap')
@auto_handle(ref=True)
def context(h):
h.method('finish', const=True)
...@@ -25,6 +25,23 @@ TEST_CASE(load_and_run) ...@@ -25,6 +25,23 @@ TEST_CASE(load_and_run)
CHECK(bool{shapes_before.front() == outputs.front().get_shape()}); CHECK(bool{shapes_before.front() == outputs.front().get_shape()});
} }
TEST_CASE(load_and_run_ctx)
{
auto p = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
migraphx::compile_options options;
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
migraphx::program_parameters pp;
auto param_shapes = p.get_parameter_shapes();
for(auto&& name : param_shapes.names())
{
pp.add(name, migraphx::argument::generate(param_shapes[name]));
}
auto ctx = p.get_context();
p.eval(pp);
ctx.finish();
}
TEST_CASE(if_pl_test) TEST_CASE(if_pl_test)
{ {
auto run_prog = [&](auto cond) { auto run_prog = [&](auto cond) {
......
...@@ -194,6 +194,8 @@ void print_program(const program& p) { std::cout << p << std::endl; } ...@@ -194,6 +194,8 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; } void print_module(const module& m) { std::cout << m << std::endl; }
migraphx::context get_context(const program& p) { return p.get_context(); }
} // namespace migraphx } // namespace migraphx
<% generate_c_api_body() %> <% generate_c_api_body() %>
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment