Unverified Commit d9a5acbd authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into jit-vector-reduce

parents d0b7fc9a a27dd28c
...@@ -9,7 +9,19 @@ namespace migraphx { ...@@ -9,7 +9,19 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name); operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v); operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v);
operation make_op_from_value(const std::string& name, const value& v);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template <class Value>
operation make_op(const std::string& name, const Value& v)
{
return make_op_from_value(name, v);
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -156,6 +156,19 @@ struct id_matcher ...@@ -156,6 +156,19 @@ struct id_matcher
} }
}; };
// Forward declare class and constructors
template <class M>
struct basic_matcher;
template <class M>
basic_matcher<M> make_basic_matcher(M m);
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f);
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p);
/// The basic matcher provides the all_of composability of the matcher /// The basic matcher provides the all_of composability of the matcher
template <class M> template <class M>
struct basic_matcher struct basic_matcher
...@@ -167,8 +180,8 @@ struct basic_matcher ...@@ -167,8 +180,8 @@ struct basic_matcher
{ {
// Copy m because we cant capture `this` by value // Copy m because we cant capture `this` by value
auto mm = m; auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, return make_basic_fun_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> { instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins); auto result = mm.match(ctx, ins);
if(result) if(result)
{ {
...@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base ...@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct matcher_result struct matcher_result
{ {
std::unordered_map<std::string, instruction_ref> instructions; struct instruction_container
{
instruction_container() = default;
instruction_container(std::unordered_map<std::string, instruction_ref> x)
: ins_map(std::move(x))
{
}
instruction_ref operator[](const std::string& name) const
{
auto it = ins_map.find(name);
if(it == ins_map.end())
MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name);
return it->second;
}
auto find(const std::string& name) const { return ins_map.find(name); }
auto begin() const { return ins_map.cbegin(); }
auto end() const { return ins_map.cend(); }
bool has_instructions_in(const module& mod) const
{
return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) {
return mod.has_instruction(p.second);
});
}
private:
std::unordered_map<std::string, instruction_ref> ins_map;
};
instruction_container instructions;
instruction_ref result; instruction_ref result;
}; };
...@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m) ...@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{ {
result.result = ins; result.result = ins;
result.instructions = ctx.instructions; result.instructions = ctx.instructions;
assert(result.instructions.has_instructions_in(mod));
} }
else else
{ {
...@@ -533,6 +579,18 @@ auto skip_output(Ms... ms) ...@@ -533,6 +579,18 @@ auto skip_output(Ms... ms)
}); });
} }
inline auto var(std::string s)
{
return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s);
if(it == ctx.instructions.end())
return nullopt;
return it->second;
});
}
inline auto name(std::string s) inline auto name(std::string s)
{ {
return make_basic_pred_matcher( return make_basic_pred_matcher(
......
...@@ -17,7 +17,7 @@ struct memory_coloring ...@@ -17,7 +17,7 @@ struct memory_coloring
std::string allocation_op{}; std::string allocation_op{};
bool verify = false; bool verify = false;
std::string name() const { return "memory coloring"; } std::string name() const { return "memory coloring"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -15,7 +15,7 @@ struct module; ...@@ -15,7 +15,7 @@ struct module;
struct propagate_constant struct propagate_constant
{ {
std::string name() const { return "propagate_constant"; } std::string name() const { return "propagate_constant"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct module; ...@@ -16,7 +16,7 @@ struct module;
struct rewrite_batchnorm struct rewrite_batchnorm
{ {
std::string name() const { return "rewrite_batchnorm"; } std::string name() const { return "rewrite_batchnorm"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -15,7 +15,7 @@ struct module; ...@@ -15,7 +15,7 @@ struct module;
struct rewrite_pooling struct rewrite_pooling
{ {
std::string name() const { return "rewrite_pooling"; } std::string name() const { return "rewrite_pooling"; }
void apply(module& prog) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -19,22 +19,22 @@ struct module; ...@@ -19,22 +19,22 @@ struct module;
struct rewrite_rnn struct rewrite_rnn
{ {
std::string name() const { return "rewrite_rnn"; } std::string name() const { return "rewrite_rnn"; }
void apply(module& prog) const; void apply(module& m) const;
private: private:
// for vanilla rnn operators // for vanilla rnn operators
void apply_vanilla_rnn(module& prog, instruction_ref ins) const; void apply_vanilla_rnn(module& m, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward, std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
module& prog, module& m,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
operation& actv_func) const; operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const; std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators // for gru operators
void apply_gru(module& prog, instruction_ref ins) const; void apply_gru(module& m, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward, std::vector<instruction_ref> gru_cell(bool is_forward,
module& prog, module& m,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int linear_before_reset, int linear_before_reset,
...@@ -44,9 +44,9 @@ struct rewrite_rnn ...@@ -44,9 +44,9 @@ struct rewrite_rnn
std::vector<operation> gru_actv_funcs(instruction_ref ins) const; std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators // for lstm operators
void apply_lstm(module& prog, instruction_ref ins) const; void apply_lstm(module& m, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward, std::vector<instruction_ref> lstm_cell(bool is_forward,
module& prog, module& m,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
const operation& actv_func1, const operation& actv_func1,
...@@ -55,24 +55,23 @@ struct rewrite_rnn ...@@ -55,24 +55,23 @@ struct rewrite_rnn
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const; std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
bool is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const; bool is_variable_seq_lens(const module& m, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(module& prog, instruction_ref replace_last_hs_output(module& m,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref last_hs_output, instruction_ref last_hs_output,
op::rnn_direction dirct) const; op::rnn_direction dirct) const;
void replace_last_cell_output(module& prog, void replace_last_cell_output(module& m,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref cell_outputs, instruction_ref cell_outputs,
instruction_ref last_cell_output, instruction_ref last_cell_output,
op::rnn_direction dirct) const; op::rnn_direction dirct) const;
std::size_t std::size_t get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const;
get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(module& prog, instruction_ref pad_hidden_states(module& m,
instruction_ref seq, instruction_ref seq,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref hs) const; instruction_ref hs) const;
......
...@@ -19,7 +19,7 @@ struct schedule ...@@ -19,7 +19,7 @@ struct schedule
schedule_model model{}; schedule_model model{};
bool enable = true; bool enable = true;
std::string name() const { return "schedule"; } std::string name() const { return "schedule"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -15,7 +15,7 @@ struct module; ...@@ -15,7 +15,7 @@ struct module;
struct simplify_algebra struct simplify_algebra
{ {
std::string name() const { return "simplify_algebra"; } std::string name() const { return "simplify_algebra"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -16,7 +16,7 @@ struct module; ...@@ -16,7 +16,7 @@ struct module;
struct simplify_reshapes struct simplify_reshapes
{ {
std::string name() const { return "simplify_reshapes"; } std::string name() const { return "simplify_reshapes"; }
void apply(module& p) const; void apply(module& m) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -5,20 +5,41 @@ namespace migraphx { ...@@ -5,20 +5,41 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name) { return load_op(name); } operation make_op(const std::string& name) { return load_op(name); }
operation make_op(const std::string& name, const value& v)
template <class F>
operation make_op_generic(const std::string& name, F for_each)
{ {
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object");
auto op = load_op(name); auto op = load_op(name);
// Merge values // Merge values
value w = op.to_value(); value w = op.to_value();
for(auto&& x : v) for_each([&](const auto& key, const auto& x) {
{ if(not w.contains(key))
w.at(x.get_key()) = x.without_key(); // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
} MIGRAPHX_THROW("No key '" + key + "' in " + name);
w.at(key) = x;
});
op.from_value(w); op.from_value(w);
return op; return op;
} }
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v)
{
return make_op_generic(name, [&](auto f) {
for(auto&& [key, x] : v)
f(key, x);
});
}
operation make_op_from_value(const std::string& name, const value& v)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object for make_op: " + name);
return make_op_generic(name, [&](auto f) {
for(auto&& x : v)
f(x.get_key(), x.without_key());
});
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(module& p) const void memory_coloring::apply(module& m) const
{ {
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{})) if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{ {
memory_coloring_impl opt(&p, allocation_op, verify); memory_coloring_impl opt(&m, allocation_op, verify);
opt.run(); opt.run();
} }
} }
......
...@@ -20,9 +20,9 @@ bool skip_propogate(instruction_ref ins) ...@@ -20,9 +20,9 @@ bool skip_propogate(instruction_ref ins)
return false; return false;
} }
void propagate_constant::apply(module& p) const void propagate_constant::apply(module& m) const
{ {
for(auto i : iterator_for(p)) for(auto i : iterator_for(m))
{ {
if(i->name() != "@literal") if(i->name() != "@literal")
continue; continue;
...@@ -42,8 +42,8 @@ void propagate_constant::apply(module& p) const ...@@ -42,8 +42,8 @@ void propagate_constant::apply(module& p) const
if(not r.empty()) if(not r.empty())
{ {
assert(r.get_shape() == child->get_shape()); assert(r.get_shape() == child->get_shape());
auto l = p.add_literal(r.get_shape(), r.data()); auto l = m.add_literal(r.get_shape(), r.data());
self(p.replace_instruction(child, l)); self(m.replace_instruction(child, l));
} }
} }
})(i); })(i);
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void rewrite_batchnorm::apply(module& p) const void rewrite_batchnorm::apply(module& m) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != "batch_norm_inference") if(ins->name() != "batch_norm_inference")
continue; continue;
...@@ -46,13 +46,13 @@ void rewrite_batchnorm::apply(module& p) const ...@@ -46,13 +46,13 @@ void rewrite_batchnorm::apply(module& p) const
}); });
auto broadcast = op::broadcast{1, ins->get_shape().lens()}; auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = p.add_literal({a.get_shape(), a.data()}); auto a_ins = m.add_literal({a.get_shape(), a.data()});
auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins); auto a_broadcast = m.insert_instruction(ins, broadcast, a_ins);
auto mul = p.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast); auto mul = m.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast);
auto b_ins = p.add_literal({b.get_shape(), b.data()}); auto b_ins = m.add_literal({b.get_shape(), b.data()});
auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins); auto b_broadcast = m.insert_instruction(ins, broadcast, b_ins);
auto add = p.insert_instruction(ins, make_op("add"), mul, b_broadcast); auto add = m.insert_instruction(ins, make_op("add"), mul, b_broadcast);
p.replace_instruction(ins, add); m.replace_instruction(ins, add);
} }
} }
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(module& prog) const void rewrite_pooling::apply(module& m) const
{ {
for(auto ins : iterator_for(prog)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != "pooling") if(ins->name() != "pooling")
continue; continue;
...@@ -33,26 +33,25 @@ void rewrite_pooling::apply(module& prog) const ...@@ -33,26 +33,25 @@ void rewrite_pooling::apply(module& prog) const
continue; continue;
std::int64_t n = s.lens()[0]; std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1]; std::int64_t c = s.lens()[1];
auto reshape = prog.insert_instruction( auto reshape = m.insert_instruction(
ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front()); ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{}; instruction_ref pooling{};
// average pooling // average pooling
if(op.mode == op::pooling_mode::average) if(op.mode == op::pooling_mode::average)
{ {
pooling = pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
} }
// max pooling // max pooling
else else
{ {
pooling = prog.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape); pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
} }
std::vector<int64_t> rsp_lens(lens.size(), 1); std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n; rsp_lens[0] = n;
rsp_lens[1] = c; rsp_lens[1] = c;
prog.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling); m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
} }
} }
......
This diff is collapsed.
...@@ -42,7 +42,7 @@ struct stream_info ...@@ -42,7 +42,7 @@ struct stream_info
std::unordered_map<instruction_ref, std::size_t> iweights; std::unordered_map<instruction_ref, std::size_t> iweights;
ins_dep_map mod_implicit_deps; ins_dep_map mod_implicit_deps;
void calc_implicit_deps(const module& p) { mod_implicit_deps = p.calc_implicit_deps(); } void calc_implicit_deps(const module& m) { mod_implicit_deps = m.calc_implicit_deps(); }
void accumulate_weights(instruction_ref last, const schedule_model& model) void accumulate_weights(instruction_ref last, const schedule_model& model)
{ {
...@@ -116,15 +116,15 @@ struct stream_info ...@@ -116,15 +116,15 @@ struct stream_info
} }
}; };
std::size_t assign_streams(module& p, std::size_t n) std::size_t assign_streams(module& m, std::size_t n)
{ {
assert(n > 0); assert(n > 0);
partition critical; partition critical;
std::unordered_map<instruction_ref, std::deque<partition>> partitions; std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size()); partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) { fix([&](auto self, auto ins, auto& part) {
assert(not is_end(ins, p.end())); assert(not is_end(ins, m.end()));
if(not p.has_instruction(ins)) if(not m.has_instruction(ins))
return; return;
if(contains(partitions, ins)) if(contains(partitions, ins))
return; return;
...@@ -151,8 +151,8 @@ struct stream_info ...@@ -151,8 +151,8 @@ struct stream_info
} }
} }
// Sort instructions // Sort instructions
p.move_instruction(ins, p.end()); m.move_instruction(ins, m.end());
})(std::prev(p.end()), critical); })(std::prev(m.end()), critical);
// Set the critical partition to stream 0 // Set the critical partition to stream 0
set_stream(critical, 0); set_stream(critical, 0);
...@@ -197,13 +197,13 @@ struct stream_info ...@@ -197,13 +197,13 @@ struct stream_info
} }
}; };
void sort(module& p, std::size_t) void sort(module& m, std::size_t)
{ {
std::set<weight_ins, compare_weight_ins> children; std::set<weight_ins, compare_weight_ins> children;
std::unordered_map<instruction_ref, std::size_t> visited; std::unordered_map<instruction_ref, std::size_t> visited;
auto last = std::prev(p.end()); auto last = std::prev(m.end());
auto mw = this->weights.at(last); auto mw = this->weights.at(last);
auto nw = mw / (p.size() + 1); auto nw = mw / (m.size() + 1);
auto add_child = [&](auto ins) { auto add_child = [&](auto ins) {
auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1); auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1);
auto w = x * this->iweights.at(ins); auto w = x * this->iweights.at(ins);
...@@ -222,10 +222,10 @@ struct stream_info ...@@ -222,10 +222,10 @@ struct stream_info
// Pop the first element // Pop the first element
auto top = children.begin()->second; auto top = children.begin()->second;
children.erase(children.begin()); children.erase(children.begin());
p.move_instruction(top, p.begin()); m.move_instruction(top, m.begin());
for(auto ins : top->inputs()) for(auto ins : top->inputs())
{ {
if(not p.has_instruction(ins)) if(not m.has_instruction(ins))
continue; continue;
add_child(ins); add_child(ins);
} }
...@@ -234,7 +234,7 @@ struct stream_info ...@@ -234,7 +234,7 @@ struct stream_info
{ {
for(auto ins : mod_implicit_deps.at(top)) for(auto ins : mod_implicit_deps.at(top))
{ {
assert(p.has_instruction(ins)); assert(m.has_instruction(ins));
add_child(ins); add_child(ins);
} }
} }
...@@ -242,12 +242,12 @@ struct stream_info ...@@ -242,12 +242,12 @@ struct stream_info
// move dangling parameter to the front so as not be removed // move dangling parameter to the front so as not be removed
auto ins = std::next(last); auto ins = std::next(last);
while(ins != p.end()) while(ins != m.end())
{ {
auto next = std::next(ins); auto next = std::next(ins);
if(ins->name() == "@param") if(ins->name() == "@param")
{ {
p.move_instruction(ins, p.begin()); m.move_instruction(ins, m.begin());
} }
ins = next; ins = next;
} }
...@@ -364,18 +364,18 @@ struct stream_info ...@@ -364,18 +364,18 @@ struct stream_info
} }
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(module& p) const find_concurrent_instructions(module& m) const
{ {
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result; std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
dominator_info di = compute_dominator(p); dominator_info di = compute_dominator(m);
result.reserve(p.size()); result.reserve(m.size());
merge_from.reserve(p.size()); merge_from.reserve(m.size());
for(auto ins : reverse_iterator_for(p)) for(auto ins : reverse_iterator_for(m))
{ {
for(auto&& arg : ins->outputs()) for(auto&& arg : ins->outputs())
{ {
if(not p.has_instruction(arg)) if(not m.has_instruction(arg))
continue; continue;
if(is_merge_point(arg)) if(is_merge_point(arg))
merge_from[ins].insert(arg); merge_from[ins].insert(arg);
...@@ -415,18 +415,18 @@ struct stream_info ...@@ -415,18 +415,18 @@ struct stream_info
} }
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(module& p) get_conflicts(module& m)
{ {
using conflict_table_type = using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
conflict_table_type conflict_table; conflict_table_type conflict_table;
auto concur_ins = this->find_concurrent_instructions(p); auto concur_ins = this->find_concurrent_instructions(m);
// Compute an index for each instruction // Compute an index for each instruction
std::unordered_map<instruction_ref, std::size_t> ins2index; std::unordered_map<instruction_ref, std::size_t> ins2index;
std::size_t index_total = 0; std::size_t index_total = 0;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
ins2index[ins] = index_total++; ins2index[ins] = index_total++;
std::vector<conflict_table_type> thread_conflict_tables( std::vector<conflict_table_type> thread_conflict_tables(
...@@ -507,21 +507,21 @@ struct stream_info ...@@ -507,21 +507,21 @@ struct stream_info
} }
}; };
void schedule::apply(module& p) const void schedule::apply(module& m) const
{ {
if(not enable) if(not enable)
return; return;
stream_info si; stream_info si;
si.calc_implicit_deps(p); si.calc_implicit_deps(m);
auto last = std::prev(p.end()); auto last = std::prev(m.end());
si.accumulate_weights(last, model); si.accumulate_weights(last, model);
auto nstreams = si.assign_streams(p, model.concurrency()); auto nstreams = si.assign_streams(m, model.concurrency());
si.sort(p, model.concurrency()); si.sort(m, model.concurrency());
if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
{ {
p.annotate(std::cout, [&](auto ins) { m.annotate(std::cout, [&](auto ins) {
if(ins->name() == "@param" and not contains(si.weights, ins)) if(ins->name() == "@param" and not contains(si.weights, ins))
return; return;
...@@ -548,9 +548,9 @@ void schedule::apply(module& p) const ...@@ -548,9 +548,9 @@ void schedule::apply(module& p) const
std::unordered_map<instruction_ref, std::size_t> ins2wait; std::unordered_map<instruction_ref, std::size_t> ins2wait;
std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for; std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited; std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
ins2wait.reserve(p.size()); ins2wait.reserve(m.size());
ins2waited.reserve(p.size()); ins2waited.reserve(m.size());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// Only schedule instructions that have a stream // Only schedule instructions that have a stream
if(not si.has_stream(ins)) if(not si.has_stream(ins))
...@@ -559,7 +559,7 @@ void schedule::apply(module& p) const ...@@ -559,7 +559,7 @@ void schedule::apply(module& p) const
// Schedule instruction on the stream // Schedule instruction on the stream
auto stream = si.get_stream(ins); auto stream = si.get_stream(ins);
assert(stream < model.concurrency()); assert(stream < model.concurrency());
model.sched(p, ins, stream); model.sched(m, ins, stream);
// Insert wait instructions // Insert wait instructions
if(si.is_merge_point(ins, stream)) if(si.is_merge_point(ins, stream))
{ {
...@@ -572,14 +572,14 @@ void schedule::apply(module& p) const ...@@ -572,14 +572,14 @@ void schedule::apply(module& p) const
if(not contains(ins2wait, i)) if(not contains(ins2wait, i))
{ {
ins2wait[i] = wait_id; ins2wait[i] = wait_id;
model.record(p, i, wait_id); model.record(m, i, wait_id);
wait_id++; wait_id++;
} }
auto w = ins2wait.at(i); auto w = ins2wait.at(i);
// If we already waited for the event on this stream then dont // If we already waited for the event on this stream then dont
// insert another wait event // insert another wait event
if(not contains(waited_for[stream], w)) if(not contains(waited_for[stream], w))
model.wait(p, ins, w); model.wait(m, ins, w);
// Store the event as waited // Store the event as waited
waited_for[stream].insert(w); waited_for[stream].insert(w);
// Store all wait events that have been waited on prior to the recorded instruction // Store all wait events that have been waited on prior to the recorded instruction
...@@ -594,7 +594,7 @@ void schedule::apply(module& p) const ...@@ -594,7 +594,7 @@ void schedule::apply(module& p) const
} }
// Add memory conflicts // Add memory conflicts
auto conflict_table = si.get_conflicts(p); auto conflict_table = si.get_conflicts(m);
for(auto&& ip : conflict_table) for(auto&& ip : conflict_table)
{ {
if(ip.second.empty()) if(ip.second.empty())
...@@ -602,7 +602,7 @@ void schedule::apply(module& p) const ...@@ -602,7 +602,7 @@ void schedule::apply(module& p) const
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
args.push_back(ip.first); args.push_back(ip.first);
args.insert(args.end(), ip.second.begin(), ip.second.end()); args.insert(args.end(), ip.second.begin(), ip.second.end());
p.insert_instruction(std::next(ip.first), make_op("identity"), args); m.insert_instruction(std::next(ip.first), make_op("identity"), args);
} }
} }
......
This diff is collapsed.
...@@ -53,7 +53,7 @@ struct match_find_quantizable_ops ...@@ -53,7 +53,7 @@ struct match_find_quantizable_ops
match::arg(1)(dequantizelinear_op("x2", "scale2"))); match::arg(1)(dequantizelinear_op("x2", "scale2")));
} }
void apply(module& m, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto qop = r.result; auto qop = r.result;
auto q1 = r.instructions["x1"]; auto q1 = r.instructions["x1"];
......
...@@ -70,19 +70,19 @@ struct find_reshaper ...@@ -70,19 +70,19 @@ struct find_reshaper
match::any_of[match::outputs()](match::name(reshaper_names()))); match::any_of[match::outputs()](match::name(reshaper_names())));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins}; std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back())) while(is_reshaper(reshapes.back()))
{ {
assert(!reshapes.back()->inputs().empty()); assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front())); assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front(); auto input = reshapes.back()->inputs().front();
reshapes.push_back(input); reshapes.push_back(input);
} }
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()}; std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()};
for(auto start : iterator_for(reshapes)) for(auto start : iterator_for(reshapes))
{ {
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) { auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
...@@ -96,7 +96,7 @@ struct find_reshaper ...@@ -96,7 +96,7 @@ struct find_reshaper
} }
if(r.first != r.second) if(r.first != r.second)
{ {
p.replace_instruction(r.first, r.second); m.replace_instruction(r.first, r.second);
} }
} }
}; };
...@@ -117,10 +117,10 @@ struct find_nop_reshapes ...@@ -117,10 +117,10 @@ struct find_nop_reshapes
return match::name(reshapes)(match::same_shape(match::arg(0))); return match::name(reshapes)(match::same_shape(match::arg(0)));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
p.replace_instruction(ins, ins->inputs().front()); m.replace_instruction(ins, ins->inputs().front());
} }
}; };
...@@ -132,7 +132,7 @@ struct find_transpose ...@@ -132,7 +132,7 @@ struct find_transpose
match::skip_output(match::name("contiguous"))(match::name("transpose")))); match::skip_output(match::name("contiguous"))(match::name("transpose"))));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto x = ins; auto x = ins;
...@@ -149,11 +149,11 @@ struct find_transpose ...@@ -149,11 +149,11 @@ struct find_transpose
return; return;
if(is_no_transpose(dims)) if(is_no_transpose(dims))
{ {
p.replace_instruction(ins, t->inputs().front()); m.replace_instruction(ins, t->inputs().front());
} }
else else
{ {
p.replace_instruction( m.replace_instruction(
ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front()); ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front());
} }
} }
...@@ -223,7 +223,7 @@ struct find_nested_slice ...@@ -223,7 +223,7 @@ struct find_nested_slice
return result; return result;
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto slice = ins->inputs().front(); auto slice = ins->inputs().front();
...@@ -241,7 +241,7 @@ struct find_nested_slice ...@@ -241,7 +241,7 @@ struct find_nested_slice
op.starts.push_back(pp.second.first); op.starts.push_back(pp.second.first);
op.ends.push_back(pp.second.second); op.ends.push_back(pp.second.second);
} }
p.replace_instruction(ins, op, input); m.replace_instruction(ins, op, input);
} }
}; };
...@@ -252,7 +252,7 @@ struct find_concat_transpose ...@@ -252,7 +252,7 @@ struct find_concat_transpose
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape())); return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto trans_inputs = ins->inputs(); auto trans_inputs = ins->inputs();
...@@ -279,14 +279,14 @@ struct find_concat_transpose ...@@ -279,14 +279,14 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform( std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) { ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction( return m.insert_instruction(
ins, make_op("transpose", {{"permutation", permutation}}), i); ins, make_op("transpose", {{"permutation", permutation}}), i);
}); });
auto concat = p.insert_instruction(ins, op, inputs); auto concat = m.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction( auto t = m.insert_instruction(
ins, make_op("transpose", {{"permutation", ipermutation}}), concat); ins, make_op("transpose", {{"permutation", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens()); assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t); m.replace_instruction(ins, t);
} }
}; };
...@@ -303,7 +303,7 @@ struct find_nested_concat ...@@ -303,7 +303,7 @@ struct find_nested_concat
return op.axis; return op.axis;
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto axis = get_axis(ins); auto axis = get_axis(ins);
...@@ -317,7 +317,7 @@ struct find_nested_concat ...@@ -317,7 +317,7 @@ struct find_nested_concat
args.push_back(i); args.push_back(i);
} }
})(ins->inputs()); })(ins->inputs());
p.replace_instruction(ins, ins->get_operator(), args); m.replace_instruction(ins, ins->get_operator(), args);
} }
}; };
...@@ -329,7 +329,7 @@ struct find_resize ...@@ -329,7 +329,7 @@ struct find_resize
match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind"))); match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto ins_rsp = r.instructions["data"]; auto ins_rsp = r.instructions["data"];
...@@ -417,13 +417,13 @@ struct find_resize ...@@ -417,13 +417,13 @@ struct find_resize
} }
auto in_rsp = ins_rsp->inputs().front(); auto in_rsp = ins_rsp->inputs().front();
auto rsp_data = p.insert_instruction( auto rsp_data = m.insert_instruction(
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp); ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = p.insert_instruction( auto mb_rsp = m.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data); ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp); auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end()); std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb); m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
} }
}; };
...@@ -436,7 +436,7 @@ struct find_where_op ...@@ -436,7 +436,7 @@ struct find_where_op
match::is_constant().bind("ind"))); match::is_constant().bind("ind")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto concat = r.instructions["data"]; auto concat = r.instructions["data"];
...@@ -475,11 +475,11 @@ struct find_where_op ...@@ -475,11 +475,11 @@ struct find_where_op
if(val) if(val)
{ {
p.replace_instruction(ins, inputs.at(0)); m.replace_instruction(ins, inputs.at(0));
} }
else else
{ {
p.replace_instruction(ins, inputs.at(1)); m.replace_instruction(ins, inputs.at(1));
} }
} }
}; };
...@@ -496,7 +496,7 @@ struct find_reshape_cont ...@@ -496,7 +496,7 @@ struct find_reshape_cont
match::any())); match::any()));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto ins_cont = r.instructions["cont"]; auto ins_cont = r.instructions["cont"];
...@@ -530,11 +530,11 @@ struct find_reshape_cont ...@@ -530,11 +530,11 @@ struct find_reshape_cont
else else
{ {
inputs.push_back( inputs.push_back(
p.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), in)); m.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), in));
} }
} }
auto out = p.insert_instruction(ins, ins->get_operator(), inputs); auto out = m.insert_instruction(ins, ins->get_operator(), inputs);
p.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), out); m.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), out);
} }
}; };
...@@ -564,25 +564,25 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -564,25 +564,25 @@ struct find_transpose_contiguous_reshaper_unary
match::args(match_transpose_contiguous_reshaper())); match::args(match_transpose_contiguous_reshaper()));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"]; auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"]; auto trans_ins = r.instructions["trans_ins"];
auto cont_ins = r.instructions["cont_ins"]; auto cont_ins = r.instructions["cont_ins"];
auto unary_op_name = ins->get_operator().name(); auto unary_op_name = ins->get_operator().name();
auto unary_ins = p.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins); auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = p.insert_instruction(cont_ins, make_op("contiguous"), unary_ins); auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
// older cont and reshape are removed by deadcode elimination // older cont and reshape are removed by deadcode elimination
p.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins); m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins);
} }
}; };
void simplify_reshapes::apply(module& p) const void simplify_reshapes::apply(module& m) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 2; i++)
{ {
match::find_matches(p, match::find_matches(m,
find_where_op{}, find_where_op{},
find_resize{}, find_resize{},
find_reshape_cont{}, find_reshape_cont{},
...@@ -594,7 +594,7 @@ void simplify_reshapes::apply(module& p) const ...@@ -594,7 +594,7 @@ void simplify_reshapes::apply(module& p) const
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
find_transpose_contiguous_reshaper_unary{}); find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(p); dead_code_elimination{}.apply(m);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment