Unverified Commit c9b86f1c authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Module impl (#678)



* add an api get_main_module

* clang format

* modify onnx unit test for module

* clang format

* refactor ops unit test with the get_main_module

* clang format

* code backup

* clang format

* refine module c api

* add python api for module

* clang format

* fix a python api issue

* clang format

* fix cppcheck error

* clang format

* refine unit tests changes

* clang format

* code backup

* code backup

* clang format

* defer some changes to later PRs

* change return of get_main_module from ref to pointer

* clang format

* add unit tests for the get_main_module_api

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* add more unit tests for more code change coverage

* clang format

* fixed a unit test error

* clang format

* fix unit test

* clang format

* code backup

* code change for more code coverage

* change program to module in various passes and matcher

* clang format

* modify the pass API

* code backup

* code backup

* clang format

* code backup

* clang format

* Add option to no generate a destroy method

* Formatting

* fix some review comments

* clang format

* fix review comments

* clang format

* clang format

* code backup

* code backup

* clang format

* fix cppcheck errors

* clang format

* clang format

* fix build errors

* clang format

* modify gpu unit tests to using module

* clang format

* fix cppcheck error

* clang format

* Add flag to enable cpu backend

* Make buffers shared

* Enable optimizations

* Formatting

* fix review comments

* code backup

* clang format

* code backup

* clang format

* fix a bug related to a unit test

* clang format

* clang format

* fix a build error

* remove unnecessary code

* remove unnecessary files

* code backup

* clang format

* remove the compile function from the module class

* clang format

* clang format

* remove the context parameter from the from_value method of the module class

* code refinement

* clang format

* merge changes from develop branch

* clang format

* fix cppcheck error

* clang format

* fix a build error

* fixed a merge error

* fix cppcheck error

* fixed review comments

* clang format

* fix cppcheck error

* fix a cppcheck error

* fix cppcheck error

* fix build error caused by merge

* Add missing has_op function

* Formatting

* merge changes from develop branch

* fix a cppcheck error

* fixed some review comments

* clang format

* remove the begin/end function of the program class

* clang format

* refine code and fix cppcheck error

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* add unit tests for more code coverage

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix a build error in debug mode

* clang format
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
parent 1dd4e4d9
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Replace instructions which take all literals with a literal of the computation. * Replace instructions which take all literals with a literal of the computation.
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Decompose operators. * Decompose operators.
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Rewrite batchnorm to a multiply and add. * Rewrite batchnorm to a multiply and add.
......
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Rewrite pooling to reduce_mean * Rewrite pooling to reduce_mean
......
...@@ -11,8 +11,7 @@ ...@@ -11,8 +11,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Rewrite rnn to gemm and add. * Rewrite rnn to gemm and add.
......
...@@ -9,8 +9,7 @@ ...@@ -9,8 +9,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Schedule instructions for concurrent execution * Schedule instructions for concurrent execution
......
...@@ -15,8 +15,7 @@ ...@@ -15,8 +15,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
struct operation; struct operation;
#ifdef DOXYGEN #ifdef DOXYGEN
......
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Simplify many algebraic instructions to more efficient versions. * Simplify many algebraic instructions to more efficient versions.
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Eliminate redundant reshapes. * Eliminate redundant reshapes.
......
This diff is collapsed.
...@@ -35,7 +35,7 @@ struct parse_onehot : op_parser<parse_onehot> ...@@ -35,7 +35,7 @@ struct parse_onehot : op_parser<parse_onehot>
auto type = args[2]->get_shape().type(); auto type = args[2]->get_shape().type();
shape s{type, {depth, depth}}; shape s{type, {depth, depth}};
auto l_val = info.mm->add_literal({s, depth_input}); auto l_val = info.add_literal({s, depth_input});
auto gather_out = info.add_instruction(make_op("gather", {{"axis", 0}}), {l_val, args[0]}); auto gather_out = info.add_instruction(make_op("gather", {{"axis", 0}}), {l_val, args[0]});
// Finally, we need a transpose to move the inner most dim to the axis dim // Finally, we need a transpose to move the inner most dim to the axis dim
......
...@@ -70,7 +70,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info, ...@@ -70,7 +70,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info,
{ {
*starts_it = idx; *starts_it = idx;
*ends_it = *starts_it + 1; *ends_it = *starts_it + 1;
slices.push_back(info.mm->add_instruction( slices.push_back(info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input)); make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
} }
// when padding on the left side, the outermost pad should be at the beginning // when padding on the left side, the outermost pad should be at the beginning
...@@ -83,7 +83,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info, ...@@ -83,7 +83,7 @@ instruction_ref reflect_pad(const onnx_parser::node_info& info,
slices.push_back(info.add_instruction( slices.push_back(info.add_instruction(
make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input)); make_op("slice", {{"axes", axes}, {"starts", starts}, {"ends", ends}}), input));
} }
input = info.mm->add_instruction(make_op("concat", {{"axis", axis}}), slices); input = info.add_instruction(make_op("concat", {{"axis", axis}}), slices);
} }
return input; return input;
} }
......
This diff is collapsed.
...@@ -208,17 +208,6 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -208,17 +208,6 @@ migraphx::shape to_shape(const py::buffer_info& info)
} }
} }
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_wrap
{
migraphx::program* prog;
operator const migraphx::program&() const { return *prog; }
operator migraphx::program&() { return *prog; }
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
MIGRAPHX_PYBIND11_MODULE(migraphx, m) MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{ {
py::class_<migraphx::shape>(m, "shape") py::class_<migraphx::shape>(m, "shape")
...@@ -258,12 +247,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -258,12 +247,11 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target"); py::class_<migraphx::target>(m, "target");
py::class_<migraphx::module_wrap>(m, "module") py::class_<migraphx::module>(m, "module")
.def("print", [](const migraphx::module_wrap& mm) { std::cout << *mm.prog << std::endl; }) .def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
.def("__eq__", std::equal_to<migraphx::program>{}) .def("__eq__", std::equal_to<migraphx::module>{})
.def("__ne__", std::not_equal_to<migraphx::program>{}) .def("__ne__", std::not_equal_to<migraphx::module>{})
.def("__repr__", .def("__repr__", [](const migraphx::module& mm) { return migraphx::to_string(mm); });
[](const migraphx::module_wrap& mm) { return migraphx::to_string(*mm.prog); });
py::class_<migraphx::program>(m, "program") py::class_<migraphx::program>(m, "program")
.def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); }) .def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
...@@ -284,7 +272,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m) ...@@ -284,7 +272,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("get_main_module", .def("get_main_module",
[](migraphx::program& p) { [](migraphx::program& p) {
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
return migraphx::module_wrap{mm}; return *mm;
}) })
.def("run", .def("run",
[](migraphx::program& p, py::dict params) { [](migraphx::program& p, py::dict params) {
......
...@@ -7,15 +7,14 @@ ...@@ -7,15 +7,14 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
namespace cpu { namespace cpu {
struct lowering struct lowering
{ {
std::string name() const { return "cpu::lowering"; } std::string name() const { return "cpu::lowering"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace cpu } // namespace cpu
......
...@@ -526,14 +526,14 @@ struct cpu_literal ...@@ -526,14 +526,14 @@ struct cpu_literal
struct cpu_apply struct cpu_apply
{ {
module* prog; module* modl;
std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{}; std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
std::unordered_map<instruction_ref, std::string> prog_output_names{}; std::unordered_map<instruction_ref, std::string> prog_output_names{};
instruction_ref last{}; instruction_ref last{};
void create_output_names() void create_output_names()
{ {
this->last = instruction::get_output_alias(std::prev(prog->end())); this->last = instruction::get_output_alias(std::prev(modl->end()));
if(this->last->name() == "@return") if(this->last->name() == "@return")
{ {
const auto& prog_outputs = last->inputs(); const auto& prog_outputs = last->inputs();
...@@ -558,7 +558,7 @@ struct cpu_apply ...@@ -558,7 +558,7 @@ struct cpu_apply
auto&& op = ins->get_operator(); auto&& op = ins->get_operator();
if(allocate) if(allocate)
replace(ins, make_op(cpu_name, op.to_value())); replace(ins, make_op(cpu_name, op.to_value()));
return prog->replace_instruction(ins, make_op(cpu_name, op.to_value()), ins->inputs()); return modl->replace_instruction(ins, make_op(cpu_name, op.to_value()), ins->inputs());
}); });
} }
...@@ -610,7 +610,7 @@ struct cpu_apply ...@@ -610,7 +610,7 @@ struct cpu_apply
void apply() void apply()
{ {
init(); init();
for(auto it : iterator_for(*prog)) for(auto it : iterator_for(*modl))
{ {
if(it->name() == "@literal") if(it->name() == "@literal")
{ {
...@@ -629,7 +629,7 @@ struct cpu_apply ...@@ -629,7 +629,7 @@ struct cpu_apply
instruction_ref apply_literal(instruction_ref ins) const instruction_ref apply_literal(instruction_ref ins) const
{ {
return prog->replace_instruction(ins, cpu_literal{ins->get_literal().get_argument()}); return modl->replace_instruction(ins, cpu_literal{ins->get_literal().get_argument()});
} }
instruction_ref apply_pooling(instruction_ref ins) instruction_ref apply_pooling(instruction_ref ins)
...@@ -651,7 +651,7 @@ struct cpu_apply ...@@ -651,7 +651,7 @@ struct cpu_apply
{ {
auto inputs = ins->inputs(); auto inputs = ins->inputs();
inputs.push_back(insert_allocation(ins, ins->get_shape())); inputs.push_back(insert_allocation(ins, ins->get_shape()));
return prog->replace_instruction(ins, op, inputs); return modl->replace_instruction(ins, op, inputs);
} }
instruction_ref insert_allocation(instruction_ref ins, const shape& s) instruction_ref insert_allocation(instruction_ref ins, const shape& s)
...@@ -659,18 +659,18 @@ struct cpu_apply ...@@ -659,18 +659,18 @@ struct cpu_apply
auto ins_alias = instruction::get_output_alias(ins); auto ins_alias = instruction::get_output_alias(ins);
if(last->name() == "@return" and prog_output_names.count(ins_alias) > 0) if(last->name() == "@return" and prog_output_names.count(ins_alias) > 0)
{ {
return prog->add_parameter(prog_output_names[ins_alias], s); return modl->add_parameter(prog_output_names[ins_alias], s);
} }
else if(ins == last) else if(ins == last)
{ {
return prog->add_parameter("output", s); return modl->add_parameter("output", s);
} }
return prog->insert_instruction(ins, make_op("cpu::allocate", {{"shape", to_value(s)}})); return modl->insert_instruction(ins, make_op("cpu::allocate", {{"shape", to_value(s)}}));
} }
}; };
void lowering::apply(module& p) const { cpu_apply{&p}.apply(); } void lowering::apply(module& m) const { cpu_apply{&m}.apply(); }
} // namespace cpu } // namespace cpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -356,7 +356,7 @@ struct find_triadd_layernorm ...@@ -356,7 +356,7 @@ struct find_triadd_layernorm
match::used_once(), match::all_of[match::inputs()](match::standard_shape())))); match::used_once(), match::all_of[match::inputs()](match::standard_shape()))));
} }
void apply(program& p, const match::matcher_result& r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto triadd = ins->inputs().front(); auto triadd = ins->inputs().front();
......
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
namespace gpu { namespace gpu {
......
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
namespace gpu { namespace gpu {
......
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
namespace gpu { namespace gpu {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment