Unverified Commit c9b86f1c authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Module impl (#678)



* add an api get_main_module

* clang format

* modify onnx unit test for module

* clang format

* refactor ops unit test with the get_main_module

* clang format

* code backup

* clang format

* refine module c api

* add python api for module

* clang format

* fix a python api issue

* clang format

* fix cppcheck error

* clang format

* refine unit tests changes

* clang format

* code backup

* code backup

* clang format

* defer some changes to later PRs

* change return of get_main_module from ref to pointer

* clang format

* add unit tests for the get_main_module_api

* clang format

* fix cppcheck error

* clang format

* fix cppcheck error

* clang format

* add more unit tests for more code change coverage

* clang format

* fixed a unit test error

* clang format

* fix unit test

* clang format

* code backup

* code change for more code coverage

* change program to module in various passes and matcher

* clang format

* modify the pass API

* code backup

* code backup

* clang format

* code backup

* clang format

* Add option to no generate a destroy method

* Formatting

* fix some review comments

* clang format

* fix review comments

* clang format

* clang format

* code backup

* code backup

* clang format

* fix cppcheck errors

* clang format

* clang format

* fix build errors

* clang format

* modify gpu unit tests to using module

* clang format

* fix cppcheck error

* clang format

* Add flag to enable cpu backend

* Make buffers shared

* Enable optimizations

* Formatting

* fix review comments

* code backup

* clang format

* code backup

* clang format

* fix a bug related to a unit test

* clang format

* clang format

* fix a build error

* remove unnecessary code

* remove unnecessary files

* code backup

* clang format

* remove the compile function from the module class

* clang format

* clang format

* remove the context parameter from the from_value method of the module class

* code refinement

* clang format

* merge changes from develop branch

* clang format

* fix cppcheck error

* clang format

* fix a build error

* fixed a merge error

* fix cppcheck error

* fixed review comments

* clang format

* fix cppcheck error

* fix a cppcheck error

* fix cppcheck error

* fix build error caused by merge

* Add missing has_op function

* Formatting

* merge changes from develop branch

* fix a cppcheck error

* fixed some review comments

* clang format

* remove the begin/end function of the program class

* clang format

* refine code and fix cppcheck error

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* add unit tests for more code coverage

* clang format

* fix review comments

* clang format

* fix review comments

* clang format

* fix a build error in debug mode

* clang format
Co-authored-by: default avatarPaul <pfultz2@yahoo.com>
parent 1dd4e4d9
...@@ -29,6 +29,7 @@ add_library(migraphx ...@@ -29,6 +29,7 @@ add_library(migraphx
msgpack.cpp msgpack.cpp
operation.cpp operation.cpp
program.cpp program.cpp
module.cpp
quantization.cpp quantization.cpp
reduce_dims.cpp reduce_dims.cpp
remap.cpp remap.cpp
......
...@@ -53,7 +53,7 @@ struct find_dot_alpha ...@@ -53,7 +53,7 @@ struct find_dot_alpha
{ {
auto matcher() const { return match::name("dot")(match::nargs(2)); } auto matcher() const { return match::name("dot")(match::nargs(2)); }
void apply(program& p, const match::matcher_result& r) const void apply(module& p, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto dot = any_cast<op::dot>(ins->get_operator()); auto dot = any_cast<op::dot>(ins->get_operator());
......
...@@ -161,9 +161,9 @@ struct loader ...@@ -161,9 +161,9 @@ struct loader
} }
if(trim > 0) if(trim > 0)
{ {
auto last = std::prev(p.end(), trim);
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
mm->remove_instructions(last, p.end()); auto last = std::prev(mm->end(), trim);
mm->remove_instructions(last, mm->end());
} }
if(optimize) if(optimize)
{ {
......
...@@ -110,8 +110,8 @@ void verify_reduced(program p, ...@@ -110,8 +110,8 @@ void verify_reduced(program p,
double tolerance) double tolerance)
{ {
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto last = std::prev(p.end(), n + 1); auto last = std::prev(mm->end(), n + 1);
mm->remove_instructions(last, p.end()); mm->remove_instructions(last, mm->end());
std::cout << "Verify: " << std::endl; std::cout << "Verify: " << std::endl;
std::cout << p << std::endl; std::cout << p << std::endl;
verify_program(std::to_string(n), p, t, options, inputs, tolerance); verify_program(std::to_string(n), p, t, options, inputs, tolerance);
...@@ -123,7 +123,8 @@ void verify_reduced_program(const program& p, ...@@ -123,7 +123,8 @@ void verify_reduced_program(const program& p,
const parameter_map& inputs, const parameter_map& inputs,
double tolerance) double tolerance)
{ {
auto n = std::distance(p.begin(), p.end()); const auto* mm = p.get_main_module();
auto n = std::distance(mm->begin(), mm->end());
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
{ {
verify_reduced(p, i, t, options, inputs, tolerance); verify_reduced(p, i, t, options, inputs, tolerance);
......
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
struct adjust_allocation struct adjust_allocation
{ {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
struct stream_race struct stream_race
{ {
...@@ -16,7 +16,7 @@ struct stream_race ...@@ -16,7 +16,7 @@ struct stream_race
instruction_ref before; instruction_ref before;
}; };
std::vector<stream_race> analyze_streams(const program& p, const stream_model& m); std::vector<stream_race> analyze_streams(const module& p, const stream_model& m);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
struct auto_contiguous struct auto_contiguous
{ {
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Remove instructions where the output is not used. * Remove instructions where the output is not used.
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Decompose operators. * Decompose operators.
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Remove memory allocations. This will create a parameter which is the max of all memory used in * Remove memory allocations. This will create a parameter which is the max of all memory used in
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Remove identical instructions. * Remove identical instructions.
......
...@@ -9,8 +9,7 @@ ...@@ -9,8 +9,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Remove concat operators by having each operator can write to different chunk of memory. * Remove concat operators by having each operator can write to different chunk of memory.
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Remove contiguous instructions by checking if the operator can use non-standard shapes. * Remove contiguous instructions by checking if the operator can use non-standard shapes.
......
...@@ -8,8 +8,7 @@ ...@@ -8,8 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Remove identity instructions. Currently when used as the last pass, it will * Remove identity instructions. Currently when used as the last pass, it will
......
...@@ -10,8 +10,7 @@ ...@@ -10,8 +10,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Remove pads if they can be written as an * Remove pads if they can be written as an
......
...@@ -7,8 +7,7 @@ ...@@ -7,8 +7,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Remove memory allocations. It uses graph coloring to find memory allocations that can be reused. * Remove memory allocations. It uses graph coloring to find memory allocations that can be reused.
......
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP
#include <list>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
const operation& get_operation(instruction_ref ins);
struct module_impl;
using parameter_map = std::unordered_map<std::string, argument>;
/**
* @brief Stores the instruction stream
*/
struct module
{
module();
// move constructor
module(module&&) noexcept;
// copy constructor
module(const module&);
// copy assignment operator
module& operator=(module);
~module() noexcept;
std::string name() const { return module_name; }
template <class... Ts>
instruction_ref add_instruction(operation op, Ts... args)
{
return add_instruction(op, {args...});
}
instruction_ref add_instruction(const operation& op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args)
{
return insert_instruction(ins, op, {args...});
}
instruction_ref
insert_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args)
{
return replace_instruction(ins, op, {args...});
}
instruction_ref replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST;
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
instruction_ref move_instruction(instruction_ref src, instruction_ref dst);
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
{
return add_literal(literal{std::forward<Ts>(xs)...});
}
instruction_ref add_literal(literal l);
instruction_ref add_outline(const shape& s);
instruction_ref add_parameter(std::string name, shape s);
instruction_ref add_return(std::vector<instruction_ref> args);
std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const;
instruction_ref get_parameter(std::string name) const;
std::unordered_map<std::string, shape> get_parameter_shapes() const;
bool has_instruction(instruction_ref ins) const;
std::size_t size() const;
instruction_ref begin() const;
instruction_ref end() const;
std::vector<shape> get_output_shapes() const;
instruction_ref validate() const;
void finalize(context& ctx);
value to_value() const;
void from_value(const value& v);
void debug_print() const;
void debug_print(instruction_ref ins) const;
void debug_print(const std::vector<instruction_ref>& inss) const;
void print(const std::function<void(instruction_ref,
const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const;
void print_cpp(std::ostream& os) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
module& sort();
friend std::ostream& operator<<(std::ostream& os, const module& m);
friend bool operator==(const module& x, const module& y);
friend bool operator!=(const module& x, const module& y) { return !(x == y); }
private:
void assign(const module& m);
std::unique_ptr<module_impl> impl;
std::string module_name;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -10,8 +10,7 @@ ...@@ -10,8 +10,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct module;
using module = program;
/** /**
* Process negative axis attributes of ops * Process negative axis attributes of ops
......
...@@ -13,7 +13,7 @@ namespace migraphx { ...@@ -13,7 +13,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program; struct module;
#ifdef DOXYGEN #ifdef DOXYGEN
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <list> #include <list>
#include <unordered_map> #include <unordered_map>
#include <migraphx/operation.hpp> #include <migraphx/operation.hpp>
#include <migraphx/module.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp> #include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
...@@ -17,16 +18,11 @@ ...@@ -17,16 +18,11 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
using module = program;
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl; struct program_impl;
const operation& get_operation(instruction_ref ins);
using parameter_map = std::unordered_map<std::string, argument>;
/** /**
* @brief Stores the instruction stream * @brief Stores the instruction stream
*/ */
...@@ -45,52 +41,6 @@ struct program ...@@ -45,52 +41,6 @@ struct program
~program() noexcept; ~program() noexcept;
template <class... Ts>
instruction_ref add_instruction(operation op, Ts... args)
{
return add_instruction(op, {args...});
}
instruction_ref add_instruction(const operation& op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args)
{
return insert_instruction(ins, op, {args...});
}
instruction_ref
insert_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args);
template <class... Ts>
instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args)
{
return replace_instruction(ins, op, {args...});
}
instruction_ref replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST;
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
instruction_ref move_instruction(instruction_ref src, instruction_ref dst);
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
{
return add_literal(literal{std::forward<Ts>(xs)...});
}
instruction_ref add_literal(literal l);
instruction_ref add_outline(const shape& s);
instruction_ref add_parameter(std::string name, shape s);
instruction_ref add_return(std::vector<instruction_ref> args);
std::vector<std::string> get_parameter_names() const; std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const; shape get_parameter_shape(std::string name) const;
...@@ -101,11 +51,7 @@ struct program ...@@ -101,11 +51,7 @@ struct program
std::vector<argument> eval(parameter_map params) const; std::vector<argument> eval(parameter_map params) const;
bool has_instruction(instruction_ref ins) const;
std::size_t size() const; std::size_t size() const;
instruction_ref begin() const;
instruction_ref end() const;
std::vector<shape> get_output_shapes() const; std::vector<shape> get_output_shapes() const;
...@@ -126,13 +72,16 @@ struct program ...@@ -126,13 +72,16 @@ struct program
void debug_print() const; void debug_print() const;
void debug_print(instruction_ref ins) const; void debug_print(instruction_ref ins) const;
void debug_print(const std::vector<instruction_ref>& inss) const; void print(const std::function<void(instruction_ref,
const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const; void print_graph(std::ostream& os, bool brief = false) const;
void print_cpp(std::ostream& os) const; void print_cpp(std::ostream& os) const;
void dry_run(parameter_map params) const; void dry_run(parameter_map params) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const; void annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const;
program& sort(); program& sort();
...@@ -140,8 +89,8 @@ struct program ...@@ -140,8 +89,8 @@ struct program
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
module* get_main_module() { return this; } module* get_main_module();
const module* get_main_module() const { return this; } const module* get_main_module() const;
private: private:
void assign(const program& p); void assign(const program& p);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment