Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
...@@ -33,7 +33,7 @@ struct check_context ...@@ -33,7 +33,7 @@ struct check_context
}; };
std::string name() const { return "check_context"; } std::string name() const { return "check_context"; }
void apply(program& p) const { p.insert_instruction(p.begin(), op{}); } void apply(module& p) const { p.insert_instruction(p.begin(), op{}); }
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Remove instructions where the output is not used. * Remove instructions where the output is not used.
...@@ -16,7 +17,7 @@ struct program; ...@@ -16,7 +17,7 @@ struct program;
struct dead_code_elimination struct dead_code_elimination
{ {
std::string name() const { return "dead_code_elimination"; } std::string name() const { return "dead_code_elimination"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Decompose operators. * Decompose operators.
...@@ -16,7 +17,7 @@ struct program; ...@@ -16,7 +17,7 @@ struct program;
struct decompose struct decompose
{ {
std::string name() const { return "decompose"; } std::string name() const { return "decompose"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Remove memory allocations. This will create a parameter which is the max of all memory used in * Remove memory allocations. This will create a parameter which is the max of all memory used in
...@@ -19,7 +20,7 @@ struct eliminate_allocation ...@@ -19,7 +20,7 @@ struct eliminate_allocation
std::string allocation_op{}; std::string allocation_op{};
std::size_t alignment = 32; std::size_t alignment = 32;
std::string name() const { return "eliminate_allocation"; } std::string name() const { return "eliminate_allocation"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Remove identical instructions. * Remove identical instructions.
...@@ -16,7 +17,7 @@ struct program; ...@@ -16,7 +17,7 @@ struct program;
struct eliminate_common_subexpression struct eliminate_common_subexpression
{ {
std::string name() const { return "eliminate_common_subexpression"; } std::string name() const { return "eliminate_common_subexpression"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -10,6 +10,7 @@ namespace migraphx { ...@@ -10,6 +10,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Remove concat operators by having each operator can write to different chunk of memory. * Remove concat operators by having each operator can write to different chunk of memory.
...@@ -18,7 +19,7 @@ struct eliminate_concat ...@@ -18,7 +19,7 @@ struct eliminate_concat
{ {
concat_optimization concat_opt; concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; } std::string name() const { return "eliminate_concat"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Remove contiguous instructions by checking if the operator can use non-standard shapes. * Remove contiguous instructions by checking if the operator can use non-standard shapes.
...@@ -16,7 +17,7 @@ struct program; ...@@ -16,7 +17,7 @@ struct program;
struct eliminate_contiguous struct eliminate_contiguous
{ {
std::string name() const { return "eliminate_contiguous"; } std::string name() const { return "eliminate_contiguous"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Remove identity instructions. Currently when used as the last pass, it will * Remove identity instructions. Currently when used as the last pass, it will
...@@ -18,7 +19,7 @@ struct program; ...@@ -18,7 +19,7 @@ struct program;
struct eliminate_identity struct eliminate_identity
{ {
std::string name() const { return "eliminate_identity"; } std::string name() const { return "eliminate_identity"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -11,6 +11,7 @@ namespace migraphx { ...@@ -11,6 +11,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Remove pads if they can be written as an * Remove pads if they can be written as an
...@@ -19,11 +20,11 @@ struct program; ...@@ -19,11 +20,11 @@ struct program;
struct eliminate_pad struct eliminate_pad
{ {
std::string name() const { return "eliminate_pad"; } std::string name() const { return "eliminate_pad"; }
void apply(program& p) const; void apply(module& p) const;
template <class T> template <class T>
void update_op(T, const instruction_ref& input, const instruction_ref& ins, program& p) const; void update_op(T, const instruction_ref& input, const instruction_ref& ins, module& p) const;
void update_pooling(const instruction_ref& input, const instruction_ref& ins, program& p) const; void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -209,7 +209,7 @@ struct matcher_result ...@@ -209,7 +209,7 @@ struct matcher_result
/// Match a single instruction /// Match a single instruction
template <class M> template <class M>
matcher_result match_instruction(program& p, instruction_ref ins, M&& m) matcher_result match_instruction(module& p, instruction_ref ins, M&& m)
{ {
assert(ins != p.end()); assert(ins != p.end());
matcher_result result; matcher_result result;
...@@ -223,7 +223,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) ...@@ -223,7 +223,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the program /// Find matches for an instruction in the program
template <class... Ms> template <class... Ms>
void find_matches(program& p, instruction_ref ins, Ms&&... ms) void find_matches(module& p, instruction_ref ins, Ms&&... ms)
{ {
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 #if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const const
...@@ -250,7 +250,7 @@ void find_matches(program& p, instruction_ref ins, Ms&&... ms) ...@@ -250,7 +250,7 @@ void find_matches(program& p, instruction_ref ins, Ms&&... ms)
/// Find matches in a program /// Find matches in a program
template <class... Ms> template <class... Ms>
void find_matches(program& p, Ms&&... ms) void find_matches(module& p, Ms&&... ms)
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
...@@ -264,7 +264,7 @@ struct find_skip ...@@ -264,7 +264,7 @@ struct find_skip
M m; M m;
M matcher() const { return m; } M matcher() const { return m; }
void apply(program&, const matcher_result&) const {} void apply(module&, const matcher_result&) const {}
}; };
template <class M> template <class M>
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Remove memory allocations. It uses graph coloring to find memory allocations that can be reused. * Remove memory allocations. It uses graph coloring to find memory allocations that can be reused.
...@@ -17,7 +18,7 @@ struct memory_coloring ...@@ -17,7 +18,7 @@ struct memory_coloring
std::string allocation_op{}; std::string allocation_op{};
bool verify = false; bool verify = false;
std::string name() const { return "memory coloring"; } std::string name() const { return "memory coloring"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -13,6 +13,7 @@ namespace migraphx { ...@@ -13,6 +13,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -23,7 +24,7 @@ struct pass ...@@ -23,7 +24,7 @@ struct pass
/// A unique name used to identify the pass /// A unique name used to identify the pass
std::string name() const; std::string name() const;
/// Run the pass on the program /// Run the pass on the program
void apply(program& p) const; void apply(module& p) const;
}; };
#else #else
...@@ -34,7 +35,7 @@ struct pass ...@@ -34,7 +35,7 @@ struct pass
* struct pass * struct pass
* { * {
* std::string name() const; * std::string name() const;
* void apply(program & p) const; * void apply(module & p) const;
* }; * };
* *
*/ */
...@@ -108,7 +109,7 @@ struct pass ...@@ -108,7 +109,7 @@ struct pass
return (*this).private_detail_te_get_handle().name(); return (*this).private_detail_te_get_handle().name();
} }
void apply(program& p) const void apply(module& p) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().apply(p); (*this).private_detail_te_get_handle().apply(p);
...@@ -127,8 +128,8 @@ struct pass ...@@ -127,8 +128,8 @@ struct pass
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual void apply(program& p) const = 0; virtual void apply(module& p) const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -161,7 +162,7 @@ struct pass ...@@ -161,7 +162,7 @@ struct pass
std::string name() const override { return private_detail_te_value.name(); } std::string name() const override { return private_detail_te_value.name(); }
void apply(program& p) const override { private_detail_te_value.apply(p); } void apply(module& p) const override { private_detail_te_value.apply(p); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace = tracer{}); void run_passes(module& modl, const std::vector<pass>& passes, tracer trace = tracer{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -17,8 +17,7 @@ ...@@ -17,8 +17,7 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
using module = program; using module = program;
using module_ref = module*;
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
...@@ -26,6 +25,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL) ...@@ -26,6 +25,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl; struct program_impl;
const operation& get_operation(instruction_ref ins); const operation& get_operation(instruction_ref ins);
using parameter_map = std::unordered_map<std::string, argument>;
/** /**
* @brief Stores the instruction stream * @brief Stores the instruction stream
...@@ -45,8 +45,6 @@ struct program ...@@ -45,8 +45,6 @@ struct program
~program() noexcept; ~program() noexcept;
using parameter_map = std::unordered_map<std::string, argument>;
template <class... Ts> template <class... Ts>
instruction_ref add_instruction(operation op, Ts... args) instruction_ref add_instruction(operation op, Ts... args)
{ {
...@@ -143,6 +141,7 @@ struct program ...@@ -143,6 +141,7 @@ struct program
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
module* get_main_module() { return this; } module* get_main_module() { return this; }
const module* get_main_module() const { return this; }
private: private:
void assign(const program& p); void assign(const program& p);
......
...@@ -8,6 +8,7 @@ namespace migraphx { ...@@ -8,6 +8,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Replace instructions which take all literals with a literal of the computation. * Replace instructions which take all literals with a literal of the computation.
...@@ -15,7 +16,7 @@ struct program; ...@@ -15,7 +16,7 @@ struct program;
struct propagate_constant struct propagate_constant
{ {
std::string name() const { return "propagate_constant"; } std::string name() const { return "propagate_constant"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -38,7 +38,7 @@ capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_name ...@@ -38,7 +38,7 @@ capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_name
void quantize_int8(program& prog, void quantize_int8(program& prog,
const target& t, const target& t,
const std::vector<program::parameter_map>& calibration, const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", "convolution"}); const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8_impl(program& prog, void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params, const std::vector<std::pair<float, float>>& quant_params,
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Decompose operators. * Decompose operators.
...@@ -16,7 +17,7 @@ struct program; ...@@ -16,7 +17,7 @@ struct program;
struct remap struct remap
{ {
std::string name() const { return "remap"; } std::string name() const { return "remap"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Rewrite batchnorm to a multiply and add. * Rewrite batchnorm to a multiply and add.
...@@ -16,7 +17,7 @@ struct program; ...@@ -16,7 +17,7 @@ struct program;
struct rewrite_batchnorm struct rewrite_batchnorm
{ {
std::string name() const { return "rewrite_batchnorm"; } std::string name() const { return "rewrite_batchnorm"; }
void apply(program& p) const; void apply(module& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -8,6 +8,7 @@ namespace migraphx { ...@@ -8,6 +8,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Rewrite pooling to reduce_mean * Rewrite pooling to reduce_mean
...@@ -15,7 +16,7 @@ struct program; ...@@ -15,7 +16,7 @@ struct program;
struct rewrite_pooling struct rewrite_pooling
{ {
std::string name() const { return "rewrite_pooling"; } std::string name() const { return "rewrite_pooling"; }
void apply(program& prog) const; void apply(module& prog) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -12,6 +12,7 @@ namespace migraphx { ...@@ -12,6 +12,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
using module = program;
/** /**
* Rewrite rnn to gemm and add. * Rewrite rnn to gemm and add.
...@@ -19,22 +20,22 @@ struct program; ...@@ -19,22 +20,22 @@ struct program;
struct rewrite_rnn struct rewrite_rnn
{ {
std::string name() const { return "rewrite_rnn"; } std::string name() const { return "rewrite_rnn"; }
void apply(program& prog) const; void apply(module& prog) const;
private: private:
// for vanilla rnn operators // for vanilla rnn operators
void apply_vanilla_rnn(program& prog, instruction_ref ins) const; void apply_vanilla_rnn(module& prog, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward, std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog, module& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
operation& actv_func) const; operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const; std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators // for gru operators
void apply_gru(program& prog, instruction_ref ins) const; void apply_gru(module& prog, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward, std::vector<instruction_ref> gru_cell(bool is_forward,
program& prog, module& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int linear_before_reset, int linear_before_reset,
...@@ -44,9 +45,9 @@ struct rewrite_rnn ...@@ -44,9 +45,9 @@ struct rewrite_rnn
std::vector<operation> gru_actv_funcs(instruction_ref ins) const; std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators // for lstm operators
void apply_lstm(program& prog, instruction_ref ins) const; void apply_lstm(module& prog, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward, std::vector<instruction_ref> lstm_cell(bool is_forward,
program& prog, module& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
const operation& actv_func1, const operation& actv_func1,
...@@ -55,14 +56,14 @@ struct rewrite_rnn ...@@ -55,14 +56,14 @@ struct rewrite_rnn
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const; std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
bool is_variable_seq_lens(const program& prog, instruction_ref seq_lens) const; bool is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(program& prog, instruction_ref replace_last_hs_output(module& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref last_hs_output, instruction_ref last_hs_output,
op::rnn_direction dirct) const; op::rnn_direction dirct) const;
void replace_last_cell_output(program& prog, void replace_last_cell_output(module& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref cell_outputs, instruction_ref cell_outputs,
...@@ -70,9 +71,9 @@ struct rewrite_rnn ...@@ -70,9 +71,9 @@ struct rewrite_rnn
op::rnn_direction dirct) const; op::rnn_direction dirct) const;
std::size_t std::size_t
get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const; get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(program& prog, instruction_ref pad_hidden_states(module& prog,
instruction_ref seq, instruction_ref seq,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref hs) const; instruction_ref hs) const;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment