Unverified Commit 6554639b authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

add get_main_module api (#665)



* add an api get_main_module

* clang format

* modify onnx unit test for module

* clang format

* refactor ops unit test with the get_main_module

* clang format

* code backup

* clang format

* refine module c api

* add python api for module

* clang format

* fix a python api issue

* clang format

* fix cppcheck error

* clang format

* refine unit tests changes

* clang format

* defer some changes to later PRs

* change return of get_main_module from ref to pointer

* clang format

* add unit tests for the get_main_module_api

* clang format

* fix cppcheck error

* clang format

* add more unit tests for more code change coverage

* clang format

* fix unit test

* clang format

* code change for more code coverage

* Add option to no generate a destroy method

* Formatting

* fix some review comments

* clang format

* fix review comments

* clang format

* clang format
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
parent 500d9441
......@@ -167,7 +167,9 @@ std::vector<argument> run(program& p, const program::parameter_map& params)
std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
void print(const program& p) { std::cout << p << std::endl; }
void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; }
} // namespace migraphx
......@@ -264,6 +266,16 @@ struct migraphx_shapes
std::vector<migraphx::shape> object;
};
extern "C" struct migraphx_module;
struct migraphx_module
{
template <class... Ts>
migraphx_module(Ts&&... xs) : object(std::forward<Ts>(xs)...)
{
}
migraphx::module object;
};
extern "C" struct migraphx_program;
struct migraphx_program
{
......@@ -616,11 +628,30 @@ migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_
});
}
extern "C" migraphx_status migraphx_module_print(const_migraphx_module_t module)
{
return migraphx::try_([&] {
if(module == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter module: Null pointer");
migraphx::print_module((module->object));
});
}
extern "C" migraphx_status migraphx_program_destroy(migraphx_program_t program)
{
return migraphx::try_([&] { destroy((program)); });
}
extern "C" migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program)
{
return migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
*out = object_cast<migraphx_module_t>((program->object).get_main_module());
});
}
extern "C" migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options* options)
......@@ -664,7 +695,7 @@ extern "C" migraphx_status migraphx_program_print(const_migraphx_program_t progr
return migraphx::try_([&] {
if(program == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter program: Null pointer");
migraphx::print((program->object));
migraphx::print_program((program->object));
});
}
......
......@@ -71,6 +71,9 @@ typedef const struct migraphx_arguments* const_migraphx_arguments_t;
typedef struct migraphx_shapes* migraphx_shapes_t;
typedef const struct migraphx_shapes* const_migraphx_shapes_t;
typedef struct migraphx_module* migraphx_module_t;
typedef const struct migraphx_module* const_migraphx_module_t;
typedef struct migraphx_program* migraphx_program_t;
typedef const struct migraphx_program* const_migraphx_program_t;
......@@ -174,8 +177,13 @@ migraphx_status migraphx_shapes_size(size_t* out, migraphx_shapes_t shapes);
migraphx_status
migraphx_shapes_get(const_migraphx_shape_t* out, migraphx_shapes_t shapes, size_t idx);
migraphx_status migraphx_module_print(const_migraphx_module_t module);
migraphx_status migraphx_program_destroy(migraphx_program_t program);
migraphx_status migraphx_program_get_main_module(migraphx_module_t* out,
migraphx_program_t program);
migraphx_status migraphx_program_compile(migraphx_program_t program,
migraphx_target_t target,
migraphx_compile_options* options);
......
......@@ -6,6 +6,7 @@
#include <exception>
#include <vector>
#include <cassert>
#include <iostream>
namespace migraphx {
inline namespace api { // NOLINT
......@@ -422,6 +423,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments>
{
const_migraphx_argument_t pout;
call(&migraphx_arguments_get, &pout, self, pidx);
return argument(pout);
}
};
......@@ -459,6 +461,14 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes>
};
};
struct module
{
migraphx_module_t mm;
module(const migraphx_module_t& m) : mm(m) {}
void print() const { call(&migraphx_module_print, mm); }
};
struct program : MIGRAPHX_HANDLE_BASE(program)
{
program() {}
......@@ -514,6 +524,13 @@ struct program : MIGRAPHX_HANDLE_BASE(program)
return pout;
}
module get_main_module()
{
migraphx_module_t p_modu;
call(&migraphx_program_get_main_module, &p_modu, this->get_handle_ptr());
return module{p_modu};
}
friend bool operator!=(const program& px, const program& py) { return !(px == py); }
};
......
......@@ -167,8 +167,14 @@ def shapes(h):
returns='const migraphx::shape&')
@auto_handle(ref=True)
def module(h):
h.method('print', invoke='migraphx::print_module($@)', const=True)
@auto_handle()
def program(h):
h.method('get_main_module', returns='migraphx::module*')
h.method(
'compile',
api.params(target='migraphx::target',
......@@ -178,7 +184,7 @@ def program(h):
h.method('get_output_shapes',
invoke='migraphx::get_output_shapes($@)',
returns='std::vector<migraphx::shape>')
h.method('print', invoke='migraphx::print($@)', const=True)
h.method('print', invoke='migraphx::print_program($@)', const=True)
h.method('sort')
h.method('run',
api.params(
......
......@@ -17,6 +17,9 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
using module = program;
using module_ref = module*;
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
......@@ -139,6 +142,8 @@ struct program
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
module* get_main_module() { return this; }
private:
void assign(const program& p);
std::unique_ptr<program_impl> impl;
......
......@@ -208,6 +208,15 @@ migraphx::shape to_shape(const py::buffer_info& info)
}
}
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_wrap
{
migraphx::program* prog;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
......@@ -247,6 +256,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module_wrap>(m, "module")
.def("print", [](const migraphx::module_wrap& mm) { std::cout << *mm.prog << std::endl; });
py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
.def("get_parameter_names", &migraphx::program::get_parameter_names)
......@@ -263,6 +275,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::arg("t"),
py::arg("offload_copy") = true,
py::arg("fast_math") = true)
.def("get_main_module",
[](migraphx::program& p) {
auto mm = p.get_main_module();
return migraphx::module_wrap{mm};
})
.def("run",
[](migraphx::program& p, py::dict params) {
migraphx::program::parameter_map pm;
......
......@@ -155,4 +155,12 @@ TEST_CASE(strided_shape)
EXPECT(s.strides() == strides);
}
TEST_CASE(get_main_module)
{
auto p = migraphx::parse_onnx("constant_fill_test.onnx");
migraphx::module mm = p.get_main_module();
mm.print();
p.print();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
This diff is collapsed.
This diff is collapsed.
......@@ -53,6 +53,14 @@ def test_add_scalar():
print(r)
def test_module():
p = migraphx.parse_onnx("add_scalar_test.onnx")
mm = p.get_main_module()
p.print()
mm.print()
test_conv_relu()
test_module()
if sys.version_info >= (3, 0):
test_add_scalar()
......@@ -167,7 +167,9 @@ std::vector<argument> run(program& p, const program::parameter_map& params)
std::vector<shape> get_output_shapes(program& p) { return p.get_output_shapes(); }
void print(const program& p) { std::cout << p << std::endl; }
void print_program(const program& p) { std::cout << p << std::endl; }
void print_module(const module& m) { std::cout << m << std::endl; }
} // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment