Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
......@@ -33,7 +33,7 @@ struct check_context
};
std::string name() const { return "check_context"; }
void apply(program& p) const { p.insert_instruction(p.begin(), op{}); }
void apply(module& p) const { p.insert_instruction(p.begin(), op{}); }
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Remove instructions where the output is not used.
......@@ -16,7 +17,7 @@ struct program;
struct dead_code_elimination
{
std::string name() const { return "dead_code_elimination"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Decompose operators.
......@@ -16,7 +17,7 @@ struct program;
struct decompose
{
std::string name() const { return "decompose"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Remove memory allocations. This will create a parameter which is the max of all memory used in
......@@ -19,7 +20,7 @@ struct eliminate_allocation
std::string allocation_op{};
std::size_t alignment = 32;
std::string name() const { return "eliminate_allocation"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Remove identical instructions.
......@@ -16,7 +17,7 @@ struct program;
struct eliminate_common_subexpression
{
std::string name() const { return "eliminate_common_subexpression"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -10,6 +10,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Remove concat operators by having each operator can write to different chunk of memory.
......@@ -18,7 +19,7 @@ struct eliminate_concat
{
concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Remove contiguous instructions by checking if the operator can use non-standard shapes.
......@@ -16,7 +17,7 @@ struct program;
struct eliminate_contiguous
{
std::string name() const { return "eliminate_contiguous"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Remove identity instructions. Currently when used as the last pass, it will
......@@ -18,7 +19,7 @@ struct program;
struct eliminate_identity
{
std::string name() const { return "eliminate_identity"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -11,6 +11,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Remove pads if they can be written as an
......@@ -19,11 +20,11 @@ struct program;
struct eliminate_pad
{
std::string name() const { return "eliminate_pad"; }
void apply(program& p) const;
void apply(module& p) const;
template <class T>
void update_op(T, const instruction_ref& input, const instruction_ref& ins, program& p) const;
void update_op(T, const instruction_ref& input, const instruction_ref& ins, module& p) const;
void update_pooling(const instruction_ref& input, const instruction_ref& ins, program& p) const;
void update_pooling(const instruction_ref& input, const instruction_ref& ins, module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -209,7 +209,7 @@ struct matcher_result
/// Match a single instruction
template <class M>
matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
matcher_result match_instruction(module& p, instruction_ref ins, M&& m)
{
assert(ins != p.end());
matcher_result result;
......@@ -223,7 +223,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the program
template <class... Ms>
void find_matches(program& p, instruction_ref ins, Ms&&... ms)
void find_matches(module& p, instruction_ref ins, Ms&&... ms)
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
......@@ -250,7 +250,7 @@ void find_matches(program& p, instruction_ref ins, Ms&&... ms)
/// Find matches in a program
template <class... Ms>
void find_matches(program& p, Ms&&... ms)
void find_matches(module& p, Ms&&... ms)
{
for(auto ins : iterator_for(p))
{
......@@ -264,7 +264,7 @@ struct find_skip
M m;
M matcher() const { return m; }
void apply(program&, const matcher_result&) const {}
void apply(module&, const matcher_result&) const {}
};
template <class M>
......
......@@ -8,6 +8,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Remove memory allocations. It uses graph coloring to find memory allocations that can be reused.
......@@ -17,7 +18,7 @@ struct memory_coloring
std::string allocation_op{};
bool verify = false;
std::string name() const { return "memory coloring"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -13,6 +13,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
#ifdef DOXYGEN
......@@ -23,7 +24,7 @@ struct pass
/// A unique name used to identify the pass
std::string name() const;
/// Run the pass on the program
void apply(program& p) const;
void apply(module& p) const;
};
#else
......@@ -34,7 +35,7 @@ struct pass
* struct pass
* {
* std::string name() const;
* void apply(program & p) const;
* void apply(module & p) const;
* };
*
*/
......@@ -108,7 +109,7 @@ struct pass
return (*this).private_detail_te_get_handle().name();
}
void apply(program& p) const
void apply(module& p) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().apply(p);
......@@ -127,8 +128,8 @@ struct pass
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual void apply(program& p) const = 0;
virtual std::string name() const = 0;
virtual void apply(module& p) const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -161,7 +162,7 @@ struct pass
std::string name() const override { return private_detail_te_value.name(); }
void apply(program& p) const override { private_detail_te_value.apply(p); }
void apply(module& p) const override { private_detail_te_value.apply(p); }
PrivateDetailTypeErasedT private_detail_te_value;
};
......
......@@ -17,7 +17,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace = tracer{});
void run_passes(module& modl, const std::vector<pass>& passes, tracer trace = tracer{});
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -17,8 +17,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
using module = program;
using module_ref = module*;
using module = program;
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_COMPILE)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
......@@ -26,6 +25,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_EVAL)
struct program_impl;
const operation& get_operation(instruction_ref ins);
using parameter_map = std::unordered_map<std::string, argument>;
/**
* @brief Stores the instruction stream
......@@ -45,8 +45,6 @@ struct program
~program() noexcept;
using parameter_map = std::unordered_map<std::string, argument>;
template <class... Ts>
instruction_ref add_instruction(operation op, Ts... args)
{
......@@ -143,6 +141,7 @@ struct program
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
module* get_main_module() { return this; }
const module* get_main_module() const { return this; }
private:
void assign(const program& p);
......
......@@ -8,6 +8,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Replace instructions which take all literals with a literal of the computation.
......@@ -15,7 +16,7 @@ struct program;
struct propagate_constant
{
std::string name() const { return "propagate_constant"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -38,7 +38,7 @@ capture_arguments(program& prog, T&& t, const std::vector<std::string>& ins_name
void quantize_int8(program& prog,
const target& t,
const std::vector<program::parameter_map>& calibration,
const std::vector<parameter_map>& calibration,
const std::vector<std::string>& ins_names = {"dot", "convolution"});
void quantize_int8_impl(program& prog,
const std::vector<std::pair<float, float>>& quant_params,
......
......@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Decompose operators.
......@@ -16,7 +17,7 @@ struct program;
struct remap
{
std::string name() const { return "remap"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Rewrite batchnorm to a multiply and add.
......@@ -16,7 +17,7 @@ struct program;
struct rewrite_batchnorm
{
std::string name() const { return "rewrite_batchnorm"; }
void apply(program& p) const;
void apply(module& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -8,6 +8,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Rewrite pooling to reduce_mean
......@@ -15,7 +16,7 @@ struct program;
struct rewrite_pooling
{
std::string name() const { return "rewrite_pooling"; }
void apply(program& prog) const;
void apply(module& prog) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -12,6 +12,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
using module = program;
/**
* Rewrite rnn to gemm and add.
......@@ -19,22 +20,22 @@ struct program;
struct rewrite_rnn
{
std::string name() const { return "rewrite_rnn"; }
void apply(program& prog) const;
void apply(module& prog) const;
private:
// for vanilla rnn operators
void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
void apply_vanilla_rnn(module& prog, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog,
module& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators
void apply_gru(program& prog, instruction_ref ins) const;
void apply_gru(module& prog, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward,
program& prog,
module& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
......@@ -44,9 +45,9 @@ struct rewrite_rnn
std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators
void apply_lstm(program& prog, instruction_ref ins) const;
void apply_lstm(module& prog, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward,
program& prog,
module& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
......@@ -55,14 +56,14 @@ struct rewrite_rnn
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
bool is_variable_seq_lens(const program& prog, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(program& prog,
bool is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(module& prog,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
op::rnn_direction dirct) const;
void replace_last_cell_output(program& prog,
void replace_last_cell_output(module& prog,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
......@@ -70,9 +71,9 @@ struct rewrite_rnn
op::rnn_direction dirct) const;
std::size_t
get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const;
get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(program& prog,
instruction_ref pad_hidden_states(module& prog,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment