Unverified Commit 6554639b authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

add get_main_module api (#665)



* add an api get_main_module

* clang format

* modify onnx unit test for module

* clang format

* refactor ops unit test with the get_main_module

* clang format

* code backup

* clang format

* refine module c api

* add python api for module

* clang format

* fix a python api issue

* clang format

* fix cppcheck error

* clang format

* refine unit tests changes

* clang format

* defer some changes to later PRs

* change return of get_main_module from ref to pointer

* clang format

* add unit tests for the get_main_module_api

* clang format

* fix cppcheck error

* clang format

* add more unit tests for more code change coverage

* clang format

* fix unit test

* clang format

* code change for more code coverage

* Add option to no generate a destroy method

* Formatting

* fix some review comments

* clang format

* fix review comments

* clang format

* clang format
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
parent 500d9441
...@@ -167,7 +167,9 @@ std::vector<argument> run(program& p, const program::parameter_map& params) ...@@ -167,7 +167,9 @@ std::vector<argument> run(program& p, const program::parameter_map& params)
std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); } std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
void print(const program& p) { std::cout << p << std::endl; } void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; }
} // namespace migraphx } // namespace migraphx
...@@ -264,6 +266,16 @@ struct migraphx_shapes ...@@ -264,6 +266,16 @@ struct migraphx_shapes
std::vector<migraphx::shape> object; std::vector<migraphx::shape> object;
}; };
extern "C" struct migraphx_module;
struct migraphx_module
{
template <class... Ts>
migraphx_module(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
migraphx::module object;
};
extern "C" struct migraphx_program; extern "C" struct migraphx_program;
struct migraphx_program struct migraphx_program
{ {
...@@ -616,11 +628,30 @@ migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_ ...@@ -616,11 +628,30 @@ migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_
}); });
} }
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{
return migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
migraphx::print_module((module->object));
});
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program) extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{ {
return migraphx::try_([&] { destroy((program)); }); return migraphx::try_([&] { destroy((program)); });
} }
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).get_main_module());
});
}
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program, extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target, migraphx_target_t target,
migraphx_compile_options* options) migraphx_compile_options* options)
...@@ -664,7 +695,7 @@ extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t progr ...@@ -664,7 +695,7 @@ extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t progr
return migraphx::try_([&] { return migraphx::try_([&] {
if(program == nullptr) if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer"); MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
migraphx::print((program->object)); migraphx::print_program((program->object));
}); });
} }
......
...@@ -71,6 +71,9 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t; ...@@ -71,6 +71,9 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t;
typedef struct migraphx_shapes* migraphx_shapes_t; typedef struct migraphx_shapes* migraphx_shapes_t;
typedef const struct migraphx_shapes* const_migraphx_shapes_t; typedef const struct migraphx_shapes* const_migraphx_shapes_t;
typedef struct migraphx_module* migraphx_module_t;
typedef const struct migraphx_module* const_migraphx_module_t;
typedef struct migraphx_program* migraphx_program_t; typedef struct migraphx_program* migraphx_program_t;
typedef const struct migraphx_program* const_migraphx_program_t; typedef const struct migraphx_program* const_migraphx_program_t;
...@@ -174,8 +177,13 @@ migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes); ...@@ -174,8 +177,13 @@ migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx); migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx);
migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_program_destroy(migraphx_program_t program); migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program);
migraphx_status migraphx_program_compile(migraphx_program_t program, migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target, migraphx_target_t target,
migraphx_compile_options* options); migraphx_compile_options* options);
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <exception> #include <exception>
#include <vector> #include <vector>
#include <cassert> #include <cassert>
#include <iostream>
namespace migraphx { namespace migraphx {
inline namespace api { // NOLINT inline namespace api { // NOLINT
...@@ -422,6 +423,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> ...@@ -422,6 +423,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{ {
const_migraphx_argument_t pout; const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx); call(&migraphx_arguments_get, &pout, self, pidx);
return argument(pout); return argument(pout);
} }
}; };
...@@ -459,6 +461,14 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> ...@@ -459,6 +461,14 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
}; };
}; };
struct module
{
migraphx_module_t mm;
module(const migraphx_module_t& m) : mm(m) {}
void print() const { call(&migraphx_module_print, mm); }
};
struct program : MIGRAPHX_HANDLE_BASE(program) struct program : MIGRAPHX_HANDLE_BASE(program)
{ {
program() {} program() {}
...@@ -514,6 +524,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program) ...@@ -514,6 +524,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return pout; return pout;
} }
module get_main_module()
{
migraphx_module_t p_modu;
call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr());
return module{p_modu};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); } friend bool operator!=(const program& px, const program& py) { return !(px == py); }
}; };
......
...@@ -167,8 +167,14 @@ def shapes(h): ...@@ -167,8 +167,14 @@ def shapes(h):
returns='const migraphx::shape&') returns='const migraphx::shape&')
@auto_handle(ref=True)
def module(h):
h.method('print', invoke='migraphx::print_module($@)', const=True)
@auto_handle() @auto_handle()
def program(h): def program(h):
h.method('get_main_module', returns='migraphx::module*')
h.method( h.method(
'compile', 'compile',
api.params(target='migraphx::target', api.params(target='migraphx::target',
...@@ -178,7 +184,7 @@ def program(h): ...@@ -178,7 +184,7 @@ def program(h):
h.method('get_output_shapes', h.method('get_output_shapes',
invoke='migraphx::get_output_shapes($@)', invoke='migraphx::get_output_shapes($@)',
returns='std::vector<migraphx::shape>') returns='std::vector<migraphx::shape>')
h.method('print', invoke='migraphx::print($@)', const=True) h.method('print', invoke='migraphx::print_program($@)', const=True)
h.method('sort') h.method('sort')
h.method('run', h.method('run',
api.params( api.params(
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
using module = program;
using module_ref = module*;
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
...@@ -139,6 +142,8 @@ struct program ...@@ -139,6 +142,8 @@ struct program
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
module* get_main_module() { return this; }
private: private:
void assign(const program& p); void assign(const program& p);
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
......
...@@ -208,6 +208,15 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -208,6 +208,15 @@ migraphx::shape to_shape(const py::buffer_info& info)
} }
} }
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_wrap
{
migraphx::program* prog;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
MIGRAPHX_PYBIND11_MODULE(migraphx, m) MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{ {
py::class_<migraphx::shape>(m, "shape") py::class_<migraphx::shape>(m, "shape")
...@@ -247,6 +256,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -247,6 +256,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target"); py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module_wrap>(m, "module")
.def("print", [](const migraphx::module_wrap& mm) { std::cout << *mm.prog << std::endl; });
py::class_<migraphx::program>(m, "program") py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); }) .def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
.def("get_parameter_names", &migraphx::program::get_parameter_names) .def("get_parameter_names", &migraphx::program::get_parameter_names)
...@@ -263,6 +275,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -263,6 +275,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("t"), py::arg("t"),
py::arg("offload_copy") = true, py::arg("offload_copy") = true,
py::arg("fast_math") = true) py::arg("fast_math") = true)
.def("get_main_module",
[](migraphx::program& p) {
auto mm = p.get_main_module();
return migraphx::module_wrap{mm};
})
.def("run", .def("run",
[](migraphx::program& p, py::dict params) { [](migraphx::program& p, py::dict params) {
migraphx::program::parameter_map pm; migraphx::program::parameter_map pm;
......
...@@ -155,4 +155,12 @@ TEST_CASE(strided_shape) ...@@ -155,4 +155,12 @@ TEST_CASE(strided_shape)
EXPECT(s.strides() == strides); EXPECT(s.strides() == strides);
} }
TEST_CASE(get_main_module)
{
auto p = migraphx::parse_onnx("constant_fill_test.onnx");
migraphx::module mm = p.get_main_module();
mm.print();
p.print();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
This diff is collapsed.
This diff is collapsed.
...@@ -53,6 +53,14 @@ def test_add_scalar(): ...@@ -53,6 +53,14 @@ def test_add_scalar():
print(r) print(r)
def test_module():
p = migraphx.parse_onnx("add_scalar_test.onnx")
mm = p.get_main_module()
p.print()
mm.print()
test_conv_relu() test_conv_relu()
test_module()
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
test_add_scalar() test_add_scalar()
...@@ -167,7 +167,9 @@ std::vector<argument> run(program& p, const program::parameter_map& params) ...@@ -167,7 +167,9 @@ std::vector<argument> run(program& p, const program::parameter_map& params)
std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); } std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
void print(const program& p) { std::cout << p << std::endl; } void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; }
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment