Commit 17f4ba28 authored by Paul

Merge branch 'jit-vector-reduce' into jit-vector-softmax

parents a8a8d868 c84154b8
......@@ -9,7 +9,19 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name);
operation make_op(const std::string& name, const value& v);
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v);
operation make_op_from_value(const std::string& name, const value& v);
// A template overload is added for migraphx::value so an initializer_list
// cannot be passed in directly. This enforces at compile time that all
// initializer_lists are key-value pairs, whereas migraphx::value allows
// other kinds of initializer_list, such as for arrays.
template <class Value>
operation make_op(const std::string& name, const Value& v)
{
return make_op_from_value(name, v);
}
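// Usage sketch for the overload set above (illustrative, assuming the "slice"
// operator): brace-initializer calls resolve to the initializer_list overload,
// which enforces key-value pairs at compile time, while a prebuilt
// migraphx::value dispatches through the template to make_op_from_value.
//
//   auto op1 = make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}});
//   migraphx::value v = {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}};
//   auto op2 = make_op("slice", v); // same result via make_op_from_value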
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -156,6 +156,19 @@ struct id_matcher
}
};
// Forward declare the class and its factory functions
template <class M>
struct basic_matcher;
template <class M>
basic_matcher<M> make_basic_matcher(M m);
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f);
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p);
/// The basic matcher provides all_of composability for matchers
template <class M>
struct basic_matcher
......@@ -167,8 +180,8 @@ struct basic_matcher
{
// Copy m because we can't capture `this` by value
auto mm = m;
return make_bf_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
return make_basic_fun_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins);
if(result)
{
......@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct matcher_result
{
std::unordered_map<std::string, instruction_ref> instructions;
struct instruction_container
{
instruction_container() = default;
instruction_container(std::unordered_map<std::string, instruction_ref> x)
: ins_map(std::move(x))
{
}
instruction_ref operator[](const std::string& name) const
{
auto it = ins_map.find(name);
if(it == ins_map.end())
MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name);
return it->second;
}
auto find(const std::string& name) const { return ins_map.find(name); }
auto begin() const { return ins_map.cbegin(); }
auto end() const { return ins_map.cend(); }
bool has_instructions_in(const module& mod) const
{
return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) {
return mod.has_instruction(p.second);
});
}
private:
std::unordered_map<std::string, instruction_ref> ins_map;
};
instruction_container instructions;
instruction_ref result;
};
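// Usage sketch, assuming the match interface above (do_something is a
// hypothetical helper): operator[] on the new instruction_container throws on
// unbound names instead of silently handing back a value-initialized
// instruction_ref, while find() still allows a safe probe.
//
//   auto r = match_instruction(mod, ins, matcher);
//   auto x = r.instructions["x"];                         // throws if "x" is unbound
//   if(r.instructions.find("y") != r.instructions.end())
//       do_something(r.instructions["y"]);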
......@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
result.result = ins;
result.instructions = ctx.instructions;
assert(result.instructions.has_instructions_in(mod));
}
else
{
......@@ -533,6 +579,18 @@ auto skip_output(Ms... ms)
});
}
inline auto var(std::string s)
{
return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s);
if(it == ctx.instructions.end())
return nullopt;
return it->second;
});
}
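// Usage sketch, assuming the existing combinators name, arg, any, and bind:
// var("x") retrieves an instruction bound earlier in the same match, so a
// pattern can require the same instruction to appear in two places.
//
//   auto same_both_args = name("add")(arg(0)(any().bind("x")), arg(1)(var("x")));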
inline auto name(std::string s)
{
return make_basic_pred_matcher(
......
......@@ -17,7 +17,7 @@ struct memory_coloring
std::string allocation_op{};
bool verify = false;
std::string name() const { return "memory coloring"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -15,7 +15,7 @@ struct module;
struct propagate_constant
{
std::string name() const { return "propagate_constant"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct module;
struct rewrite_batchnorm
{
std::string name() const { return "rewrite_batchnorm"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -15,7 +15,7 @@ struct module;
struct rewrite_pooling
{
std::string name() const { return "rewrite_pooling"; }
void apply(module& prog) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -19,22 +19,22 @@ struct module;
struct rewrite_rnn
{
std::string name() const { return "rewrite_rnn"; }
void apply(module& prog) const;
void apply(module& m) const;
private:
// for vanilla rnn operators
void apply_vanilla_rnn(module& prog, instruction_ref ins) const;
void apply_vanilla_rnn(module& m, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators
void apply_gru(module& prog, instruction_ref ins) const;
void apply_gru(module& m, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
......@@ -44,9 +44,9 @@ struct rewrite_rnn
std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
// for lstm operators
void apply_lstm(module& prog, instruction_ref ins) const;
void apply_lstm(module& m, instruction_ref ins) const;
std::vector<instruction_ref> lstm_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
......@@ -55,24 +55,23 @@ struct rewrite_rnn
std::vector<operation> lstm_actv_funcs(instruction_ref ins) const;
bool is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(module& prog,
bool is_variable_seq_lens(const module& m, instruction_ref seq_lens) const;
instruction_ref replace_last_hs_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
op::rnn_direction dirct) const;
void replace_last_cell_output(module& prog,
void replace_last_cell_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
instruction_ref last_cell_output,
op::rnn_direction dirct) const;
std::size_t
get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const;
std::size_t get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const;
instruction_ref pad_hidden_states(module& prog,
instruction_ref pad_hidden_states(module& m,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const;
......
......@@ -19,7 +19,7 @@ struct schedule
schedule_model model{};
bool enable = true;
std::string name() const { return "schedule"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -15,7 +15,7 @@ struct module;
struct simplify_algebra
{
std::string name() const { return "simplify_algebra"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct module;
struct simplify_reshapes
{
std::string name() const { return "simplify_reshapes"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -5,20 +5,41 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
operation make_op(const std::string& name) { return load_op(name); }
operation make_op(const std::string& name, const value& v)
template <class F>
operation make_op_generic(const std::string& name, F for_each)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object");
auto op = load_op(name);
// Merge values
value w = op.to_value();
for(auto&& x : v)
{
w.at(x.get_key()) = x.without_key();
}
for_each([&](const auto& key, const auto& x) {
if(not w.contains(key))
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
MIGRAPHX_THROW("No key '" + key + "' in " + name);
w.at(key) = x;
});
op.from_value(w);
return op;
}
operation make_op(const std::string& name,
const std::initializer_list<std::pair<std::string, value>>& v)
{
return make_op_generic(name, [&](auto f) {
for(auto&& [key, x] : v)
f(key, x);
});
}
operation make_op_from_value(const std::string& name, const value& v)
{
if(not(v.is_object() or (v.empty() and v.is_array())))
MIGRAPHX_THROW("Value is not an object for make_op: " + name);
return make_op_generic(name, [&](auto f) {
for(auto&& x : v)
f(x.get_key(), x.without_key());
});
}
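// Usage sketch for the split above: make_op_generic centralizes the merge and
// now rejects unknown keys, so both entry points throw on a mistyped
// attribute name ("slice" here is illustrative).
//
//   auto op = make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}); // ok
//   // make_op("slice", {{"axis", {0}}}); // throws: No key 'axis' in slice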
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -22,6 +22,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_FINALIZE)
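// Assumption: as with other MIGRAPHX_* switches, exporting
// MIGRAPHX_TRACE_FINALIZE=1 turns on the per-instruction trace emitted in
// module::finalize below.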
struct module_impl
{
// A list is used to keep references to an instruction stable
......@@ -553,8 +555,14 @@ instruction_ref module::find_dangling_reference() const
void module::finalize(context& ctx)
{
const bool trace = enabled(MIGRAPHX_TRACE_FINALIZE{});
for(auto ins : iterator_for(*this))
{
if(trace)
{
std::cout << "Finalize: ";
this->debug_print(ins);
}
ins->finalize(ctx);
for(const auto& smod : ins->module_inputs())
{
......
......@@ -4,11 +4,11 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void memory_coloring::apply(module& p) const
void memory_coloring::apply(module& m) const
{
if(!enabled(MIGRAPHX_DISABLE_MEMORY_COLORING{}))
{
memory_coloring_impl opt(&p, allocation_op, verify);
memory_coloring_impl opt(&m, allocation_op, verify);
opt.run();
}
}
......
......@@ -20,9 +20,9 @@ bool skip_propogate(instruction_ref ins)
return false;
}
void propagate_constant::apply(module& p) const
void propagate_constant::apply(module& m) const
{
for(auto i : iterator_for(p))
for(auto i : iterator_for(m))
{
if(i->name() != "@literal")
continue;
......@@ -42,8 +42,8 @@ void propagate_constant::apply(module& p) const
if(not r.empty())
{
assert(r.get_shape() == child->get_shape());
auto l = p.add_literal(r.get_shape(), r.data());
self(p.replace_instruction(child, l));
auto l = m.add_literal(r.get_shape(), r.data());
self(m.replace_instruction(child, l));
}
}
})(i);
......
......@@ -14,9 +14,9 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_batchnorm::apply(module& p) const
void rewrite_batchnorm::apply(module& m) const
{
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
if(ins->name() != "batch_norm_inference")
continue;
......@@ -46,13 +46,13 @@ void rewrite_batchnorm::apply(module& p) const
});
auto broadcast = op::broadcast{1, ins->get_shape().lens()};
auto a_ins = p.add_literal({a.get_shape(), a.data()});
auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
auto mul = p.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast);
auto b_ins = p.add_literal({b.get_shape(), b.data()});
auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
auto add = p.insert_instruction(ins, make_op("add"), mul, b_broadcast);
p.replace_instruction(ins, add);
auto a_ins = m.add_literal({a.get_shape(), a.data()});
auto a_broadcast = m.insert_instruction(ins, broadcast, a_ins);
auto mul = m.insert_instruction(ins, make_op("mul"), ins->inputs().front(), a_broadcast);
auto b_ins = m.add_literal({b.get_shape(), b.data()});
auto b_broadcast = m.insert_instruction(ins, broadcast, b_ins);
auto add = m.insert_instruction(ins, make_op("add"), mul, b_broadcast);
m.replace_instruction(ins, add);
}
}
......
......@@ -12,9 +12,9 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(module& prog) const
void rewrite_pooling::apply(module& m) const
{
for(auto ins : iterator_for(prog))
for(auto ins : iterator_for(m))
{
if(ins->name() != "pooling")
continue;
......@@ -33,26 +33,25 @@ void rewrite_pooling::apply(module& prog) const
continue;
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
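// Sketch of the rewrite performed below: a pooling whose window spans the
// whole spatial area is lowered to reshape {n * c, -1} -> reduce over axis 1
// -> reshape back to {n, c, 1, ..., 1}.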
auto reshape = prog.insert_instruction(
auto reshape = m.insert_instruction(
ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{};
// average pooling
if(op.mode == op::pooling_mode::average)
{
pooling =
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
}
// max pooling
else
{
pooling = prog.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
}
std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n;
rsp_lens[1] = c;
prog.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
}
}
......
......@@ -30,27 +30,27 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(module& prog) const
void rewrite_rnn::apply(module& m) const
{
for(auto ins : iterator_for(prog))
for(auto ins : iterator_for(m))
{
if(ins->name() == "rnn")
{
apply_vanilla_rnn(prog, ins);
apply_vanilla_rnn(m, ins);
}
else if(ins->name() == "gru")
{
apply_gru(prog, ins);
apply_gru(m, ins);
}
else if(ins->name() == "lstm")
{
apply_lstm(prog, ins);
apply_lstm(m, ins);
}
}
}
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
{
assert(ins->name() == "rnn");
// could be 3 to 6 inputs, but the parse_rnn function will
......@@ -71,37 +71,37 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
op::rnn_direction dirct = rnn_op.direction;
// process sequence length
instruction_ref seq_lens = prog.end();
instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
{
seq_lens = args[4];
}
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
instruction_ref last_output{};
if(dirct == op::rnn_direction::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(
auto w_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
auto w_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(
auto r_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
auto r_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(
bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
bias_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
}
......@@ -111,57 +111,56 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(
ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ih_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward =
vanilla_rnn_cell(true,
prog,
m,
ins,
{args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
actv_funcs.at(0));
if(variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret_reverse =
vanilla_rnn_cell(false,
prog,
m,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
actv_funcs.at(1));
auto concat_output = prog.insert_instruction(
auto concat_output = m.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
last_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
// The following logic ensures that the last instruction rewritten from the
// rnn operator is a concat instruction
// sequence len is 1
if(ret_forward[0] == prog.end())
if(ret_forward[0] == m.end())
{
prog.replace_instruction(
m.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] = prog.insert_instruction(
ret_forward[0] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[0] = prog.insert_instruction(
ret_reverse[0] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(
m.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
}
}
......@@ -175,7 +174,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
auto r = args[2];
// process bias and initial hidden state
instruction_ref bias = prog.end();
instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
......@@ -189,43 +188,42 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
ih = m.add_literal(migraphx::literal{ih_shape, data});
}
if(!is_forward and variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret = vanilla_rnn_cell(
is_forward, prog, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
is_forward, m, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
// the following logic ensures the last instruction is a
// concat instruction
// sequence len is 1
if(ret[0] == prog.end())
if(ret[0] == m.end())
{
prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
}
}
// in case all sequences are of the same length and shorter than the
// max sequence length, pad 0's at the end of the output hidden states
ins = pad_hidden_states(prog, args[0], seq_lens, ins);
replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
ins = pad_hidden_states(m, args[0], seq_lens, ins);
replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
operation& actv_func) const
......@@ -240,60 +238,60 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tran_sw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// squeeze and transpose r
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tran_sr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
auto sih_lens = sih->get_shape().lens();
// bias
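// ONNX packs the RNN bias as B = [Wb, Rb] (2 * hidden_size); the two halves
// only ever appear summed, so wb + rb is broadcast once and reused each step.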
instruction_ref bb{};
if(bias != prog.end())
if(bias != m.end())
{
long hs = static_cast<long>(r->get_shape().lens()[2]);
auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = prog.insert_instruction(
auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias);
auto rb = prog.insert_instruction(
auto rb = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
bb = prog.insert_instruction(
auto wrb = m.insert_instruction(ins, make_op("add"), wb, rb);
bb = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
}
instruction_ref hidden_out = prog.end();
instruction_ref hidden_out = m.end();
instruction_ref last_out{};
last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
long seq_len = get_seq_len(prog, seq, seq_lens);
last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
long seq_len = get_seq_len(m, seq, seq_lens);
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(
auto xt = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq);
auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_wi = prog.insert_instruction(ins, make_op("dot"), xt, tran_sw);
auto ht_ri = prog.insert_instruction(ins, make_op("dot"), sih, tran_sr);
if(bias != prog.end())
auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_wi = m.insert_instruction(ins, make_op("dot"), xt, tran_sw);
auto ht_ri = m.insert_instruction(ins, make_op("dot"), sih, tran_sr);
if(bias != m.end())
{
xt_wi = prog.insert_instruction(ins, make_op("add"), xt_wi, bb);
xt_wi = m.insert_instruction(ins, make_op("add"), xt_wi, bb);
}
auto xt_ht = prog.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);
auto xt_ht = m.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);
// apply activation function
auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
auto ht = m.insert_instruction(ins, actv_func, xt_ht);
sih = ht;
// add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions)
last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
// concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is a concat, then we have
......@@ -304,14 +302,14 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
{
hidden_out = (seq_index == 0)
? last_out
: prog.insert_instruction(
: m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out);
}
else
{
hidden_out = (seq_index == seq_len - 1)
? last_out
: prog.insert_instruction(
: m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out);
}
}
......@@ -358,7 +356,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
}
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
{
assert(ins->name() == "gru");
const auto actv_funcs = gru_actv_funcs(ins);
......@@ -379,37 +377,37 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
op::rnn_direction dirct = gru_op.direction;
// process sequence length
instruction_ref seq_lens = prog.end();
instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
{
seq_lens = args[4];
}
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
instruction_ref last_output{};
if(dirct == op::rnn_direction::bidirectional)
{
// w weight matrix
auto w_forward = prog.insert_instruction(
auto w_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
auto w_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// r weight matrix
auto r_forward = prog.insert_instruction(
auto r_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
auto r_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(
bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
bias_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
}
......@@ -418,20 +416,20 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(
ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ih_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward =
gru_cell(true,
prog,
m,
ins,
{args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
gru_op.linear_before_reset,
......@@ -440,38 +438,37 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
if(variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret_reverse =
gru_cell(false,
prog,
m,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
gru_op.linear_before_reset,
actv_funcs.at(2),
actv_funcs.at(3));
auto concat_output = prog.insert_instruction(
auto concat_output = m.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
last_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
// The following logic ensures that the last instruction rewritten
// from the gru operator is a concat
if(ret_forward[0] == prog.end())
if(ret_forward[0] == m.end())
{
prog.replace_instruction(
m.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] = prog.insert_instruction(
ret_forward[0] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[0] = prog.insert_instruction(
ret_reverse[0] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(
m.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
}
}
......@@ -483,7 +480,7 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
auto r = args[2];
// bias
instruction_ref bias = prog.end();
instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
......@@ -497,47 +494,46 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
ih = m.add_literal(migraphx::literal{ih_shape, data});
}
if(!is_forward and variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret = gru_cell(is_forward,
prog,
m,
ins,
{args[0], w, r, bias, seq_lens, ih},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
if(ret[0] == prog.end())
if(ret[0] == m.end())
{
prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
}
}
// in case all sequences are of the same length and shorter than the
// max sequence length, pad 0's at the end of the output hidden states
ins = pad_hidden_states(prog, args[0], seq_lens, ins);
replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
ins = pad_hidden_states(m, args[0], seq_lens, ins);
replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
......@@ -552,7 +548,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto seq_lens = inputs.at(4);
auto ih = inputs.at(5);
instruction_ref hidden_states = prog.end();
instruction_ref hidden_states = m.end();
instruction_ref last_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
......@@ -560,127 +556,127 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
std::vector<float> data(ss.elements(), 1.0f);
auto l1 = prog.add_literal(migraphx::literal{ss, data});
auto l1 = m.add_literal(migraphx::literal{ss, data});
// squeeze the w matrix to 2 dims and transpose it
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r is sliced into two parts, zr and h
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto rzr = prog.insert_instruction(
auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto rzr = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
auto trzr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
auto rh = prog.insert_instruction(
auto rh = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
auto trh = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
// initial states
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
size_t bs = ih->get_shape().lens()[1];
// bias
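// ONNX packs the GRU bias as B = [Wb[zrh], Rb[zrh]] (6 * hidden_size). Wb is
// added to the input projection in one piece, while Rb is split into its z/r
// part and its h part because linear_before_reset changes where Rbh is applied.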
instruction_ref bwb{};
instruction_ref brb_zr{};
instruction_ref brb_h{};
if(bias != prog.end())
if(bias != m.end())
{
auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = prog.insert_instruction(
auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
bwb = prog.insert_instruction(
bwb = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
wb);
auto rb_zr = prog.insert_instruction(
auto rb_zr = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}),
sbias);
auto rb_h = prog.insert_instruction(
auto rb_h = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}),
sbias);
brb_zr = prog.insert_instruction(
brb_zr = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
rb_zr);
brb_h = prog.insert_instruction(
brb_h = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
rb_h);
}
long seq_len = get_seq_len(prog, seq, seq_lens);
long seq_len = get_seq_len(m, seq, seq_lens);
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(
auto xt = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq);
auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_w = prog.insert_instruction(ins, make_op("dot"), xt, tw);
auto ih1_rzr = prog.insert_instruction(ins, make_op("dot"), sih, trzr);
if(bias != prog.end())
auto xt_w = m.insert_instruction(ins, make_op("dot"), xt, tw);
auto ih1_rzr = m.insert_instruction(ins, make_op("dot"), sih, trzr);
if(bias != m.end())
{
xt_w = prog.insert_instruction(ins, make_op("add"), xt_w, bwb);
ih1_rzr = prog.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
xt_w = m.insert_instruction(ins, make_op("add"), xt_w, bwb);
ih1_rzr = m.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
}
auto xw_z = prog.insert_instruction(
auto xw_z = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w);
auto xw_r = prog.insert_instruction(
auto xw_r = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w);
auto xw_h = prog.insert_instruction(
auto xw_h = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w);
auto hr_z = prog.insert_instruction(
auto hr_z = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr);
auto hr_r = prog.insert_instruction(
auto hr_r = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr);
auto xw_hr_z = prog.insert_instruction(ins, make_op("add"), xw_z, hr_z);
auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z);
auto xw_hr_z = m.insert_instruction(ins, make_op("add"), xw_z, hr_z);
auto zt = m.insert_instruction(ins, actv_func1, xw_hr_z);
auto xw_hr_r = prog.insert_instruction(ins, make_op("add"), xw_r, hr_r);
auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r);
auto xw_hr_r = m.insert_instruction(ins, make_op("add"), xw_r, hr_r);
auto rt = m.insert_instruction(ins, actv_func1, xw_hr_r);
instruction_ref hr_h{};
if(linear_before_reset == 0)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_ht1 = prog.insert_instruction(ins, make_op("mul"), rt, sih);
hr_h = prog.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
if(bias != prog.end())
auto rt_ht1 = m.insert_instruction(ins, make_op("mul"), rt, sih);
hr_h = m.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
if(bias != m.end())
{
hr_h = prog.insert_instruction(ins, make_op("add"), hr_h, brb_h);
hr_h = m.insert_instruction(ins, make_op("add"), hr_h, brb_h);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto ht1_rh = prog.insert_instruction(ins, make_op("dot"), sih, trh);
if(bias != prog.end())
auto ht1_rh = m.insert_instruction(ins, make_op("dot"), sih, trh);
if(bias != m.end())
{
ht1_rh = prog.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
ht1_rh = m.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
}
hr_h = prog.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
hr_h = m.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
}
auto xw_hr_h = prog.insert_instruction(ins, make_op("add"), xw_h, hr_h);
auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h);
auto xw_hr_h = m.insert_instruction(ins, make_op("add"), xw_h, hr_h);
auto ht = m.insert_instruction(ins, actv_func2, xw_hr_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, make_op("sub"), l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, make_op("mul"), zt, sih);
sih = prog.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
last_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
auto one_minus_zt = m.insert_instruction(ins, make_op("sub"), l1, zt);
auto one_minus_zt_ht = m.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
auto zt_ht1 = m.insert_instruction(ins, make_op("mul"), zt, sih);
sih = m.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
last_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
if(i < seq_len - 1)
{
......@@ -689,7 +685,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
hidden_states =
(seq_index == 0)
? last_output
: prog.insert_instruction(
: m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output);
}
else
......@@ -697,7 +693,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
hidden_states =
(seq_index == seq_len - 1)
? last_output
: prog.insert_instruction(
: m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states);
}
}
......@@ -748,7 +744,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
// for lstm operators
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
{
assert(ins->name() == "lstm");
auto args = ins->inputs();
......@@ -767,13 +763,13 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
op::rnn_direction dirct = lstm_op.direction;
// process sequence length
instruction_ref seq_lens = prog.end();
instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
{
seq_lens = args[4];
}
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
instruction_ref last_hs_output{};
instruction_ref last_cell_output{};
......@@ -783,25 +779,25 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
{
// input weight matrix
// input weight matrix
auto w_forward = prog.insert_instruction(
auto w_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
auto w_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(
auto r_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
auto r_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(
bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
bias_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
}
......@@ -810,15 +806,15 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{};
if(args.size() >= 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(
ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ih_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ih_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
ih_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process initial cell value
......@@ -826,30 +822,30 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref ic_reverse{};
if(args.size() >= 7 && args[6]->name() != "undefined")
{
ic_forward = prog.insert_instruction(
ic_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
ic_reverse = prog.insert_instruction(
ic_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]);
}
else
{
ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ic_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
ic_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process weight of the peephole
instruction_ref pph_forward = prog.end();
instruction_ref pph_reverse = prog.end();
instruction_ref pph_forward = m.end();
instruction_ref pph_reverse = m.end();
if(args.size() == 8 && args[7]->name() != "undefined")
{
pph_forward = prog.insert_instruction(
pph_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
pph_reverse = prog.insert_instruction(
pph_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]);
}
auto ret_forward = lstm_cell(true,
prog,
m,
ins,
{args[0],
w_forward,
......@@ -865,11 +861,11 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
if(variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret_reverse = lstm_cell(false,
prog,
m,
ins,
{args[0],
w_reverse,
......@@ -883,36 +879,36 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
actv_funcs.at(4),
actv_funcs.at(5));
auto concat_hs_output = prog.insert_instruction(
auto concat_hs_output = m.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
auto concat_cell_output = prog.insert_instruction(
auto concat_cell_output = m.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
last_hs_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
last_cell_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);
m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);
// the following logic ensures the last instruction is a concat
if(ret_forward[0] == prog.end())
if(ret_forward[0] == m.end())
{
cell_outputs = concat_cell_output;
}
else
{
ret_forward[1] = prog.insert_instruction(
ret_forward[1] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[1] = prog.insert_instruction(
ret_reverse[1] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
ret_forward[3] = prog.insert_instruction(
ret_forward[3] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]);
ret_reverse[3] = prog.insert_instruction(
ret_reverse[3] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]);
cell_outputs = prog.insert_instruction(
cell_outputs = m.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
}
hidden_state = prog.replace_instruction(
hidden_state = m.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]});
}
else
......@@ -923,7 +919,7 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
auto r = args[2];
// bias
instruction_ref bias = prog.end();
instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
......@@ -937,7 +933,7 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
}
else
{
ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ih = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// initial cell value
......@@ -948,11 +944,11 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
}
else
{
ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ic = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process weight of the peephole
instruction_ref pph = prog.end();
instruction_ref pph = m.end();
if(args.size() == 8 && args[7]->name() != "undefined")
{
pph = args[7];
......@@ -960,54 +956,53 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
if(!is_forward and variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret = lstm_cell(is_forward,
prog,
m,
ins,
{args[0], w, r, bias, seq_lens, ih, ic, pph},
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2));
last_hs_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
last_cell_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);
last_hs_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
last_cell_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);
if(ret[0] == prog.end())
if(ret[0] == m.end())
{
cell_outputs = ret[3];
hidden_state = prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
hidden_state = m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
}
else
{
auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
cell_outputs = prog.insert_instruction(
cell_outputs = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
hidden_state = prog.replace_instruction(
hidden_state = m.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
}
}
// in case all sequences are of the same length and shorter than the
// max sequence length, pad 0's at the end of the output hidden states
hidden_state = pad_hidden_states(prog, args[0], seq_lens, hidden_state);
hidden_state = pad_hidden_states(m, args[0], seq_lens, hidden_state);
// replace last hidden states with corresponding instructions
ins = replace_last_hs_output(prog, hidden_state, seq_lens, last_hs_output, dirct);
ins = replace_last_hs_output(m, hidden_state, seq_lens, last_hs_output, dirct);
// replace last cell outputs with corresponding instructions
replace_last_cell_output(prog, ins, seq_lens, cell_outputs, last_cell_output, dirct);
replace_last_cell_output(m, ins, seq_lens, cell_outputs, last_cell_output, dirct);
}
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
......@@ -1025,8 +1020,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ic = inputs.at(6);
auto pph = inputs.at(7);
instruction_ref hidden_states = prog.end();
instruction_ref cell_outputs = prog.end();
instruction_ref hidden_states = m.end();
instruction_ref cell_outputs = m.end();
instruction_ref last_hs_output{};
instruction_ref last_cell_output{};
......@@ -1037,35 +1032,35 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
std::vector<int64_t> perm{1, 0};
// w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tsw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r matrix, squeeze and transpose
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tsr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
// initial cell state
auto sic = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
auto sic = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
auto ic_lens = sic->get_shape().lens();
// bias
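// ONNX packs the LSTM bias as B = [Wb[iofc], Rb[iofc]] (8 * hidden_size);
// both halves appear summed in every gate, so they are added and broadcast
// once up front.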
instruction_ref wrb{};
if(bias != prog.end())
if(bias != m.end())
{
auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto ub_wb = prog.insert_instruction(
auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto ub_wb = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias);
auto ub_rb = prog.insert_instruction(
auto ub_rb = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}),
sbias);
auto ub_wrb = prog.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);
auto ub_wrb = m.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);
wrb = prog.insert_instruction(
wrb = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
ub_wrb);
......@@ -1075,92 +1070,91 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref pphi_brcst{};
instruction_ref ppho_brcst{};
instruction_ref pphf_brcst{};
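// ONNX packs the peephole weights as P = [Pi, Po, Pf] (3 * hidden_size);
// each slice is broadcast to the cell-state shape before the element-wise
// multiplies below.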
if(pph != prog.end())
if(pph != m.end())
{
auto spph = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
auto pphi = prog.insert_instruction(
auto spph = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
auto pphi = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
pphi_brcst = prog.insert_instruction(
pphi_brcst = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);
auto ppho = prog.insert_instruction(
auto ppho = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
ppho_brcst = prog.insert_instruction(
ppho_brcst = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);
auto pphf = prog.insert_instruction(
auto pphf = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
pphf_brcst = prog.insert_instruction(
pphf_brcst = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
}
long seq_len = get_seq_len(prog, seq, seq_lens);
long seq_len = get_seq_len(m, seq, seq_lens);
for(long i = 0; i < seq_len; ++i)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(
auto xt = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq);
auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_tsw = prog.insert_instruction(ins, make_op("dot"), xt, tsw);
auto sih_tsr = prog.insert_instruction(ins, make_op("dot"), sih, tsr);
auto xt_sih = prog.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
if(bias != prog.end())
auto xt_tsw = m.insert_instruction(ins, make_op("dot"), xt, tsw);
auto sih_tsr = m.insert_instruction(ins, make_op("dot"), sih, tsr);
auto xt_sih = m.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
if(bias != m.end())
{
xt_sih = prog.insert_instruction(ins, make_op("add"), xt_sih, wrb);
xt_sih = m.insert_instruction(ins, make_op("add"), xt_sih, wrb);
}
auto it_before_actv = prog.insert_instruction(
auto it_before_actv = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih);
auto ot_before_actv = prog.insert_instruction(
auto ot_before_actv = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih);
auto ft_before_actv = prog.insert_instruction(
auto ft_before_actv = m.insert_instruction(
ins,
make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}),
xt_sih);
auto ct_before_actv = prog.insert_instruction(
auto ct_before_actv = m.insert_instruction(
ins,
make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}),
xt_sih);
if(pph != prog.end())
if(pph != m.end())
{
auto pphi_ct = prog.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);
auto pphi_ct = m.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
it_before_actv = m.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);
auto pphf_ct = prog.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
auto pphf_ct = m.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
ft_before_actv = m.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
}
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
auto it = m.insert_instruction(ins, actv_func1, it_before_actv);
auto ft = m.insert_instruction(ins, actv_func1, ft_before_actv);
auto ct = m.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct
auto ft_cell = prog.insert_instruction(ins, make_op("mul"), ft, sic);
auto it_ct = prog.insert_instruction(ins, make_op("mul"), it, ct);
auto cellt = prog.insert_instruction(ins, make_op("add"), ft_cell, it_ct);
auto ft_cell = m.insert_instruction(ins, make_op("mul"), ft, sic);
auto it_ct = m.insert_instruction(ins, make_op("mul"), it, ct);
auto cellt = m.insert_instruction(ins, make_op("add"), ft_cell, it_ct);
if(pph != prog.end())
if(pph != m.end())
{
auto ppho_cellt = prog.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
ot_before_actv =
prog.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
auto ppho_cellt = m.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
ot_before_actv = m.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
}
auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);
auto ot = m.insert_instruction(ins, actv_func1, ot_before_actv);
// Ht = ot (.) h(Ct)
auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
auto ht = prog.insert_instruction(ins, make_op("mul"), ot, h_cellt);
auto h_cellt = m.insert_instruction(ins, actv_func3, cellt);
auto ht = m.insert_instruction(ins, make_op("mul"), ot, h_cellt);
sic = cellt;
sih = ht;
last_hs_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
last_hs_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
last_cell_output =
prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);
m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);
if(i < seq_len - 1)
{
......@@ -1173,12 +1167,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
{
auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
hidden_states = prog.insert_instruction(
hidden_states = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1);
auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
cell_outputs = prog.insert_instruction(
cell_outputs = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
}
}
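
The loop above unrolls one LSTM time step per iteration. As a reference, here is a minimal scalar sketch of the cell equations it emits, with sigmoid/tanh standing in for actv_func1/2/3, bias omitted, and the peephole broadcasts reduced to scalars; all weight names are illustrative, not MIGraphX API:

#include <cmath>

struct lstm_state { double h; double c; };

static double sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

// One scalar LSTM step with peepholes (pi, pf, po), mirroring the slices above.
static lstm_state lstm_step(double xt, lstm_state s,
                            double wi, double ri, double pi,
                            double wf, double rf, double pf,
                            double wc, double rc,
                            double wo, double ro, double po)
{
    double it = sigmoid(wi * xt + ri * s.h + pi * s.c);   // input gate
    double ft = sigmoid(wf * xt + rf * s.h + pf * s.c);   // forget gate
    double ct = std::tanh(wc * xt + rc * s.h);            // candidate cell
    double cellt = ft * s.c + it * ct;                    // Ct = ft (.) Ct-1 + it (.) ct
    double ot = sigmoid(wo * xt + ro * s.h + po * cellt); // output gate (peephole on Ct)
    return {ot * std::tanh(cellt), cellt};                // Ht = ot (.) h(Ct)
}
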
......@@ -1266,10 +1260,10 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
}
}
bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const
bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens) const
{
bool is_var_lens = false;
if(seq_lens != prog.end())
if(seq_lens != m.end())
{
if(seq_lens->can_eval())
{
......@@ -1296,12 +1290,12 @@ bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_l
}
std::size_t
rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const
rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const
{
bool is_var_lens = is_variable_seq_lens(prog, seq_lens);
bool is_var_lens = is_variable_seq_lens(m, seq_lens);
auto input_shape = input->get_shape();
auto length = input_shape.lens()[0];
if(!is_var_lens and seq_lens != prog.end())
if(!is_var_lens and seq_lens != m.end())
{
auto arg_len = seq_lens->eval();
std::vector<std::size_t> vec_lens;
......@@ -1312,33 +1306,33 @@ rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_
return length;
}
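
When seq_lens evaluates to a constant, a single common value can stand in for the per-batch lengths; the actual uniformity test lives partly outside this hunk, so the following is only a standalone sketch of that kind of check, with vec_lens and default_len as illustrative names:

#include <algorithm>
#include <cstddef>
#include <vector>

static std::size_t common_seq_len(const std::vector<std::size_t>& vec_lens,
                                  std::size_t default_len)
{
    // fall back to the input's first dimension when lengths differ or are absent
    if(vec_lens.empty() or
       not std::all_of(vec_lens.begin(), vec_lens.end(), [&](std::size_t l) {
           return l == vec_lens.front();
       }))
        return default_len;
    return vec_lens.front();
}
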
instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
instruction_ref rewrite_rnn::replace_last_hs_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
op::rnn_direction dirct) const
{
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
instruction_ref result_ins{};
if(variable_seq_len)
{
result_ins = prog.insert_instruction(
std::next(ins),
make_op("rnn_var_sl_shift_output",
{{"output_name", "hidden_states"}, {"direction", dirct}}),
ins,
seq_lens);
prog.replace_instruction(ins, result_ins);
result_ins =
m.insert_instruction(std::next(ins),
make_op("rnn_var_sl_shift_output",
{{"output_name", "hidden_states"}, {"direction", dirct}}),
ins,
seq_lens);
m.replace_instruction(ins, result_ins);
auto hs_outputs = find_all(result_ins->outputs(),
[&](auto i) { return i->name() == "rnn_last_hs_output"; });
for(auto& hs_out : hs_outputs)
{
auto inputs = hs_out->inputs();
prog.replace_instruction(hs_out,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
inputs.front(),
seq_lens);
m.replace_instruction(hs_out,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
inputs.front(),
seq_lens);
}
}
else
......@@ -1348,7 +1342,7 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
for(auto& hs_out : hs_outputs)
{
prog.replace_instruction(hs_out, last_hs_output);
m.replace_instruction(hs_out, last_hs_output);
}
result_ins = ins;
......@@ -1357,14 +1351,14 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
return result_ins;
}
void rewrite_rnn::replace_last_cell_output(module& prog,
void rewrite_rnn::replace_last_cell_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
instruction_ref last_cell_output,
op::rnn_direction dirct) const
{
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
auto ins_outputs =
find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; });
......@@ -1372,7 +1366,7 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
{
if(!ins_outputs.empty())
{
cell_outputs = prog.insert_instruction(
cell_outputs = m.insert_instruction(
std::next(ins),
make_op("rnn_var_sl_shift_output",
{{"output_name", "cell_outputs"}, {"direction", dirct}}),
......@@ -1382,10 +1376,10 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
for(auto co : ins_outputs)
{
prog.replace_instruction(co,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
cell_outputs,
seq_lens);
m.replace_instruction(co,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
cell_outputs,
seq_lens);
}
}
// replace the rnn_last_cell_output with the last_cell_output. The while
......@@ -1394,18 +1388,18 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
{
for(auto co : ins_outputs)
{
prog.replace_instruction(co, last_cell_output);
m.replace_instruction(co, last_cell_output);
}
}
}
instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
instruction_ref rewrite_rnn::pad_hidden_states(module& m,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const
{
auto max_seq_len = seq->get_shape().lens()[0];
auto seq_len = get_seq_len(prog, seq, seq_lens);
auto seq_len = get_seq_len(m, seq, seq_lens);
// when all sequences are of the same length and
// less than max_seq_len, we need to append the hs outputs
......@@ -1417,10 +1411,9 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
pad_lens[0] = static_cast<std::size_t>(max_seq_len - seq_len);
shape pad_s{s.type(), pad_lens};
std::vector<float> pad_data(pad_s.elements(), 0.0f);
auto pl = prog.add_literal(pad_s, pad_data.begin(), pad_data.end());
hs_padded =
prog.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
prog.replace_instruction(hs, hs_padded);
auto pl = m.add_literal(pad_s, pad_data.begin(), pad_data.end());
hs_padded = m.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
m.replace_instruction(hs, hs_padded);
}
return hs_padded;
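
The padding above appends a zero literal covering the missing time steps. The shape arithmetic in isolation, as a sketch with illustrative names:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Lengths of the zero pad concatenated on axis 0: only the sequence
// dimension changes, to the max_seq_len - seq_len remaining steps.
static std::vector<std::size_t> pad_lens_for(std::vector<std::size_t> hs_lens,
                                             std::size_t max_seq_len,
                                             std::size_t seq_len)
{
    hs_lens[0] = max_seq_len - seq_len;
    return hs_lens;
}

static std::size_t pad_elements(const std::vector<std::size_t>& lens)
{
    return std::accumulate(lens.begin(), lens.end(), std::size_t{1},
                           std::multiplies<>{});
}
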
......
......@@ -42,7 +42,7 @@ struct stream_info
std::unordered_map<instruction_ref, std::size_t> iweights;
ins_dep_map mod_implicit_deps;
void calc_implicit_deps(const module& p) { mod_implicit_deps = p.calc_implicit_deps(); }
void calc_implicit_deps(const module& m) { mod_implicit_deps = m.calc_implicit_deps(); }
void accumulate_weights(instruction_ref last, const schedule_model& model)
{
......@@ -116,15 +116,15 @@ struct stream_info
}
};
std::size_t assign_streams(module& p, std::size_t n)
std::size_t assign_streams(module& m, std::size_t n)
{
assert(n > 0);
partition critical;
std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) {
assert(not is_end(ins, p.end()));
if(not p.has_instruction(ins))
assert(not is_end(ins, m.end()));
if(not m.has_instruction(ins))
return;
if(contains(partitions, ins))
return;
......@@ -151,8 +151,8 @@ struct stream_info
}
}
// Sort instructions
p.move_instruction(ins, p.end());
})(std::prev(p.end()), critical);
m.move_instruction(ins, m.end());
})(std::prev(m.end()), critical);
// Set the critical partition to stream 0
set_stream(critical, 0);
......@@ -197,13 +197,13 @@ struct stream_info
}
};
void sort(module& p, std::size_t)
void sort(module& m, std::size_t)
{
std::set<weight_ins, compare_weight_ins> children;
std::unordered_map<instruction_ref, std::size_t> visited;
auto last = std::prev(p.end());
auto last = std::prev(m.end());
auto mw = this->weights.at(last);
auto nw = mw / (p.size() + 1);
auto nw = mw / (m.size() + 1);
auto add_child = [&](auto ins) {
auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1);
auto w = x * this->iweights.at(ins);
......@@ -222,10 +222,10 @@ struct stream_info
// Pop the first element
auto top = children.begin()->second;
children.erase(children.begin());
p.move_instruction(top, p.begin());
m.move_instruction(top, m.begin());
for(auto ins : top->inputs())
{
if(not p.has_instruction(ins))
if(not m.has_instruction(ins))
continue;
add_child(ins);
}
......@@ -234,7 +234,7 @@ struct stream_info
{
for(auto ins : mod_implicit_deps.at(top))
{
assert(p.has_instruction(ins));
assert(m.has_instruction(ins));
add_child(ins);
}
}
......@@ -242,12 +242,12 @@ struct stream_info
// move dangling parameters to the front so as not to be removed
auto ins = std::next(last);
while(ins != p.end())
while(ins != m.end())
{
auto next = std::next(ins);
if(ins->name() == "@param")
{
p.move_instruction(ins, p.begin());
m.move_instruction(ins, m.begin());
}
ins = next;
}
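
The ordering key computed by add_child above, restated as a standalone sketch of the same formula; the tie-breaking itself lives in compare_weight_ins, which is outside this hunk, and all parameter names here are illustrative:

#include <cstddef>

static std::size_t sort_key(std::size_t mw,          // total weight at the last instruction
                            std::size_t nw,          // normalization bucket size
                            std::size_t ins_weight,  // accumulated weight of this instruction
                            std::size_t ins_iweight) // per-instruction weight
{
    std::size_t x = 1 + (mw - ins_weight) / (nw + 1);
    return x * ins_iweight;
}
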
......@@ -364,18 +364,18 @@ struct stream_info
}
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(module& p) const
find_concurrent_instructions(module& m) const
{
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
dominator_info di = compute_dominator(p);
result.reserve(p.size());
merge_from.reserve(p.size());
for(auto ins : reverse_iterator_for(p))
dominator_info di = compute_dominator(m);
result.reserve(m.size());
merge_from.reserve(m.size());
for(auto ins : reverse_iterator_for(m))
{
for(auto&& arg : ins->outputs())
{
if(not p.has_instruction(arg))
if(not m.has_instruction(arg))
continue;
if(is_merge_point(arg))
merge_from[ins].insert(arg);
......@@ -415,18 +415,18 @@ struct stream_info
}
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(module& p)
get_conflicts(module& m)
{
using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
conflict_table_type conflict_table;
auto concur_ins = this->find_concurrent_instructions(p);
auto concur_ins = this->find_concurrent_instructions(m);
// Compute an index for each instruction
std::unordered_map<instruction_ref, std::size_t> ins2index;
std::size_t index_total = 0;
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
ins2index[ins] = index_total++;
std::vector<conflict_table_type> thread_conflict_tables(
......@@ -507,21 +507,21 @@ struct stream_info
}
};
void schedule::apply(module& p) const
void schedule::apply(module& m) const
{
if(not enable)
return;
stream_info si;
si.calc_implicit_deps(p);
auto last = std::prev(p.end());
si.calc_implicit_deps(m);
auto last = std::prev(m.end());
si.accumulate_weights(last, model);
auto nstreams = si.assign_streams(p, model.concurrency());
si.sort(p, model.concurrency());
auto nstreams = si.assign_streams(m, model.concurrency());
si.sort(m, model.concurrency());
if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
{
p.annotate(std::cout, [&](auto ins) {
m.annotate(std::cout, [&](auto ins) {
if(ins->name() == "@param" and not contains(si.weights, ins))
return;
......@@ -548,9 +548,9 @@ void schedule::apply(module& p) const
std::unordered_map<instruction_ref, std::size_t> ins2wait;
std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
ins2wait.reserve(p.size());
ins2waited.reserve(p.size());
for(auto ins : iterator_for(p))
ins2wait.reserve(m.size());
ins2waited.reserve(m.size());
for(auto ins : iterator_for(m))
{
// Only schedule instructions that have a stream
if(not si.has_stream(ins))
......@@ -559,7 +559,7 @@ void schedule::apply(module& p) const
// Schedule instruction on the stream
auto stream = si.get_stream(ins);
assert(stream < model.concurrency());
model.sched(p, ins, stream);
model.sched(m, ins, stream);
// Insert wait instructions
if(si.is_merge_point(ins, stream))
{
......@@ -572,14 +572,14 @@ void schedule::apply(module& p) const
if(not contains(ins2wait, i))
{
ins2wait[i] = wait_id;
model.record(p, i, wait_id);
model.record(m, i, wait_id);
wait_id++;
}
auto w = ins2wait.at(i);
// If we already waited for the event on this stream then don't
// insert another wait event
if(not contains(waited_for[stream], w))
model.wait(p, ins, w);
model.wait(m, ins, w);
// Store the event as waited
waited_for[stream].insert(w);
// Store all wait events that have been waited on prior to the recorded instruction
......@@ -594,7 +594,7 @@ void schedule::apply(module& p) const
}
// Add memory conflicts
auto conflict_table = si.get_conflicts(p);
auto conflict_table = si.get_conflicts(m);
for(auto&& ip : conflict_table)
{
if(ip.second.empty())
......@@ -602,7 +602,7 @@ void schedule::apply(module& p) const
std::vector<instruction_ref> args;
args.push_back(ip.first);
args.insert(args.end(), ip.second.begin(), ip.second.end());
p.insert_instruction(std::next(ip.first), make_op("identity"), args);
m.insert_instruction(std::next(ip.first), make_op("identity"), args);
}
}
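
The identity trick above, shown in isolation: giving one op every conflicting value as an input keeps their outputs live at the same time, so later memory coloring cannot place them in the same buffer. This sketch assumes the MIGraphX headers used throughout this diff and reuses only calls visible in it (make_op, insert_instruction); a and b are hypothetical instructions from different streams:

static void annotate_conflict(module& m, instruction_ref a, instruction_ref b)
{
    std::vector<instruction_ref> args = {a, b};
    // identity forwards its first argument but extends the liveness of both
    m.insert_instruction(std::next(a), make_op("identity"), args);
}
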
......
......@@ -42,7 +42,7 @@ struct find_mul_conv
match::name("broadcast").bind("a")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto conv_ins = r.instructions["conv"];
......@@ -53,14 +53,14 @@ struct find_mul_conv
if(broadcast_op.axis != 1)
return;
auto new_a = p.insert_instruction(
auto new_a = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = p.insert_instruction(
auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
p.replace_instruction(ins, new_conv);
m.replace_instruction(ins, new_conv);
}
};
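
find_mul_conv relies on convolution being linear in its weights: a per-channel scale applied after the convolution equals convolving with weights scaled along the output-channel axis, hence the axis == 1 guard on the broadcast. A scalar sanity check of the identity, illustrative only:

#include <cassert>

int main()
{
    double x = 3.0, w = 2.0, a = 5.0; // input, weight, scale
    double after_conv = a * (w * x);  // mul applied to the conv output
    double scaled_w = (a * w) * x;    // conv with pre-scaled weights
    assert(after_conv == scaled_w);
}
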
......@@ -80,7 +80,7 @@ struct find_mul_slice_conv
match::name("broadcast")(match::is_constant()).bind("a")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto slice_ins = r.instructions["slice"];
......@@ -116,38 +116,38 @@ struct find_mul_slice_conv
auto w_slice_op = slice_op;
w_slice_op.axes = {0};
auto slice_w_ins = p.insert_instruction(ins, w_slice_op, w_ins);
auto slice_w_ins = m.insert_instruction(ins, w_slice_op, w_ins);
auto new_a = p.insert_instruction(
auto new_a = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"out_lens", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
std::vector<instruction_ref> sliced_weights;
if(slice_op.starts.front() != 0)
sliced_weights.push_back(p.insert_instruction(
sliced_weights.push_back(m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", slice_op.starts}}),
w_ins));
sliced_weights.push_back(new_mul);
int64_t end_axis = w_ins->get_shape().lens().at(0);
if(slice_op.ends.front() != end_axis)
sliced_weights.push_back(p.insert_instruction(
sliced_weights.push_back(m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", slice_op.ends}, {"ends", {end_axis}}}),
w_ins));
auto new_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 0}}), sliced_weights);
m.insert_instruction(ins, make_op("concat", {{"axis", 0}}), sliced_weights);
auto new_conv = p.insert_instruction(
auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights);
assert(conv_ins->get_shape() == new_conv->get_shape());
auto slice1 = p.insert_instruction(ins, slice_op, new_conv);
auto slice1 = m.insert_instruction(ins, slice_op, new_conv);
assert(ins->get_shape().lens() == slice1->get_shape().lens());
p.replace_instruction(ins, slice1);
m.replace_instruction(ins, slice1);
// TODO: Check each slice doesn't overlap and that it occurs after slice_ins
auto outputs = conv_ins->outputs();
for(auto output : outputs)
......@@ -171,7 +171,7 @@ struct find_mul_add
match::is_constant().bind("a")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
......@@ -179,9 +179,9 @@ struct find_mul_add
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);
auto ax_ins = p.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = p.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
p.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
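
find_mul_add distributes a constant multiplier over an add, a * (x + b) -> a * x + a * b, so the a * b term becomes a pure constant that later constant folding can evaluate. Numerically, as an illustrative check:

#include <cassert>

int main()
{
    double a = 2.0, b = 3.0, x = 7.0;
    double ab = a * b; // foldable: both operands are constants
    assert(a * (x + b) == a * x + ab);
}
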
......@@ -193,15 +193,15 @@ struct find_add_lit_broadcast
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins);
p.replace_instruction(ins, make_op("add"), x_ins, sumab);
auto sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), x_ins, sumab);
}
};
......@@ -213,7 +213,7 @@ struct find_double_add_lit_broadcast
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -228,17 +228,17 @@ struct find_double_add_lit_broadcast
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return;
auto op = a_ins->get_operator();
auto presum = p.insert_instruction(
auto presum = m.insert_instruction(
ins, make_op("add"), a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = p.insert_instruction(ins, op, presum);
sumab = m.insert_instruction(ins, op, presum);
}
else
{
sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins);
sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins);
}
auto sumxy = p.insert_instruction(ins, make_op("add"), x_ins, y_ins);
p.replace_instruction(ins, make_op("add"), sumxy, sumab);
auto sumxy = m.insert_instruction(ins, make_op("add"), x_ins, y_ins);
m.replace_instruction(ins, make_op("add"), sumxy, sumab);
}
};
......@@ -251,7 +251,7 @@ struct find_inner_broadcast
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -263,9 +263,9 @@ struct find_inner_broadcast
if(xbroadcast.axis != ybroadcast.axis)
return;
auto op = p.insert_instruction(
auto op = m.insert_instruction(
ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
p.replace_instruction(ins, xbroadcast, op);
m.replace_instruction(ins, xbroadcast, op);
}
};
......@@ -296,7 +296,7 @@ struct find_concat_op
return op.name() == "broadcast" or op.attributes().contains("pointwise");
}
void apply(module& p, const match::matcher_result& r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto axis = any_cast<op::concat>(ins->get_operator()).axis;
......@@ -330,10 +330,10 @@ struct find_concat_op
return j->inputs().at(i);
});
auto concat =
p.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs);
m.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs);
concats.push_back(concat);
}
auto y = p.insert_instruction(ins, op, concats);
auto y = m.insert_instruction(ins, op, concats);
return {y};
};
......@@ -349,9 +349,9 @@ struct find_concat_op
};
group_unique(ins->inputs().begin(), ins->inputs().end(), update_args, pred);
if(args.size() == 1)
p.replace_instruction(ins, args.front());
m.replace_instruction(ins, args.front());
else
p.replace_instruction(ins, make_op("concat", {{"axis", axis}}), args);
m.replace_instruction(ins, make_op("concat", {{"axis", axis}}), args);
}
};
......@@ -478,14 +478,14 @@ struct find_splits
return true;
}
void apply(module& p, const match::matcher_result& r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto splits = get_splits(ins);
if(splits.empty())
return;
for(const auto& group : get_split_groups(p, splits))
for(const auto& group : get_split_groups(m, splits))
{
auto start = group.front();
auto split_front = splits.front();
......@@ -500,10 +500,10 @@ struct find_splits
std::next(group.begin()), group.end(), [&](auto i) { return i == start; }));
auto split_idx = 0;
instruction_ref c = p.end();
instruction_ref c = m.end();
if(start->inputs().size() == 1)
{
c = p.insert_instruction(std::next(ins), op, ins);
c = m.insert_instruction(std::next(ins), op, ins);
}
else if(start->inputs().size() == 2)
{
......@@ -530,7 +530,7 @@ struct find_splits
return;
for(auto data : data_args)
p.move_instructions(data, ins);
m.move_instructions(data, ins);
auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
assert(not slice_op.axes.empty());
......@@ -538,16 +538,16 @@ struct find_splits
return;
auto concat_axis = slice_op.axes.front();
// TODO: Check if axes match
auto concat = p.insert_instruction(
auto concat = m.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), data_args);
std::vector<instruction_ref> args;
args.resize(2);
args[split_idx] = ins;
args[data_idx] = concat;
c = p.insert_instruction(std::next(ins), op, args);
c = m.insert_instruction(std::next(ins), op, args);
}
if(c != p.end())
if(c != m.end())
{
for(auto i : group)
{
......@@ -560,11 +560,11 @@ struct find_splits
if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
continue;
auto x =
p.insert_instruction(output, make_op("contiguous"), output->inputs());
p.replace_instruction(output, output->get_operator(), x);
m.insert_instruction(output, make_op("contiguous"), output->inputs());
m.replace_instruction(output, output->get_operator(), x);
}
p.replace_instruction(i, split->get_operator(), c);
m.replace_instruction(i, split->get_operator(), c);
}
}
}
......@@ -579,7 +579,7 @@ struct find_split_concat
match::name("slice")(match::all_of[match::outputs()](match::name("concat")))));
}
void apply(module& p, const match::matcher_result& r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -619,9 +619,9 @@ struct find_split_concat
args.erase(std::next(it), it + splits.size());
if(args.size() == 1)
p.replace_instruction(concat, args.front());
m.replace_instruction(concat, args.front());
else
p.replace_instruction(concat, concat->get_operator(), args);
m.replace_instruction(concat, concat->get_operator(), args);
}
};
......@@ -664,7 +664,7 @@ struct find_add_convs
return x.stride[0] / y.stride[0];
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_conv = r.instructions["a"];
......@@ -693,7 +693,7 @@ struct find_add_convs
if(n == 0)
return;
new_op = a_op;
b_input = p.insert_instruction(
b_input = m.insert_instruction(
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), b_input);
}
else if(b_op.stride < a_op.stride)
......@@ -702,7 +702,7 @@ struct find_add_convs
if(n == 0)
return;
new_op = b_op;
a_input = p.insert_instruction(
a_input = m.insert_instruction(
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), a_input);
}
else
......@@ -713,10 +713,10 @@ struct find_add_convs
}
auto concat_input =
p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_input, b_input);
m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_input, b_input);
auto concat_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_weights, b_weights);
p.replace_instruction(ins, new_op, concat_input, concat_weights);
m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_weights, b_weights);
m.replace_instruction(ins, new_op, concat_input, concat_weights);
}
};
......@@ -737,7 +737,7 @@ struct find_conv_dot_horiz_fusion
{
auto matcher() const { return horiz_conv_dot(); }
void apply(module& p, const match::matcher_result& r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -785,16 +785,16 @@ struct find_conv_dot_horiz_fusion
}
for(auto arg : args)
p.move_instructions(arg, input);
m.move_instructions(arg, input);
// TODO: Check if axes match
auto concat =
p.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = p.insert_instruction(std::next(input), op, input, concat);
m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = m.insert_instruction(std::next(input), op, input, concat);
int64_t offset = 0;
for(auto arg : range(start, last))
{
int64_t len = arg->get_shape().lens()[axis];
p.replace_instruction(
m.replace_instruction(
arg,
make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
......@@ -815,16 +815,16 @@ struct find_div_const
return match::name("div")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
auto recip = p.insert_instruction(std::next(c_ins), make_op("recip"), c_ins);
auto recip = m.insert_instruction(std::next(c_ins), make_op("recip"), c_ins);
auto args = ins->inputs();
p.replace_instruction(ins, make_op("mul"), args.front(), recip);
m.replace_instruction(ins, make_op("mul"), args.front(), recip);
}
};
......@@ -835,16 +835,16 @@ struct find_sub_const
return match::name("sub")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
auto neg = p.insert_instruction(std::next(c_ins), make_op("neg"), c_ins);
auto neg = m.insert_instruction(std::next(c_ins), make_op("neg"), c_ins);
auto args = ins->inputs();
p.replace_instruction(ins, make_op("add"), args.front(), neg);
m.replace_instruction(ins, make_op("add"), args.front(), neg);
}
};
......@@ -856,12 +856,12 @@ struct find_rsqrt
match::name("sqrt")(match::used_once(), match::args(match::any().bind("x")))));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
p.replace_instruction(ins, make_op("rsqrt"), x_ins);
m.replace_instruction(ins, make_op("rsqrt"), x_ins);
}
};
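
The three rewrites above trade a division, a subtraction, and a recip-of-sqrt chain for cheaper or foldable forms; since recip(c) and neg(c) are inserted right after the constant, they are evaluated once and folded. Scalar models of the equivalences, illustrative only:

#include <cassert>
#include <cmath>

static double rsqrt(double x) { return 1.0 / std::sqrt(x); } // the fused form modeled

int main()
{
    double x = 10.0, c = 4.0;
    double recip_c = 1.0 / c;     // recip(c), computed once at the constant
    double neg_c   = -c;          // neg(c), likewise foldable
    assert(x / c == x * recip_c); // div -> mul
    assert(x - c == x + neg_c);   // sub -> add
    assert(rsqrt(x) == 1.0 / std::sqrt(x)); // recip(sqrt(x)) -> rsqrt(x)
}
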
......@@ -881,7 +881,7 @@ struct find_split_reshape
.bind("reshape");
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"];
......@@ -936,14 +936,14 @@ struct find_split_reshape
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction
auto rsp_ins = p.insert_instruction(
auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace the original reshapes with slices
int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{
p.replace_instruction(
m.replace_instruction(
vec_rsp[i],
make_op(
"slice",
......@@ -962,7 +962,7 @@ struct find_split_transpose
.bind("trans");
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto slc = r.instructions["slice"];
auto trans = r.instructions["trans"];
......@@ -988,7 +988,7 @@ struct find_split_transpose
}
// insert a transpose instruction
auto tr = p.insert_instruction(
auto tr = m.insert_instruction(
std::next(input), make_op("transpose", {{"permutation", perm}}), input);
// compute the axis in the slice
......@@ -1003,7 +1003,7 @@ struct find_split_transpose
auto starts = oper.starts;
auto ends = oper.ends;
auto tr_orig = in->outputs().front();
p.replace_instruction(
m.replace_instruction(
tr_orig,
make_op("slice", {{"axes", {axis_new}}, {"starts", starts}, {"ends", ends}}),
tr);
......@@ -1011,12 +1011,12 @@ struct find_split_transpose
}
};
void simplify_algebra::apply(module& p) const
void simplify_algebra::apply(module& m) const
{
// Run simplifications multiple times
for(int i = 0; i < 8; i++)
{
match::find_matches(p,
match::find_matches(m,
find_inner_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
......@@ -1033,7 +1033,7 @@ void simplify_algebra::apply(module& p) const
find_splits{},
find_split_reshape{},
find_split_transpose{});
dead_code_elimination{}.apply(p);
dead_code_elimination{}.apply(m);
}
}
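
The pass runs its matcher set a fixed eight times with dead-code elimination between rounds, since one rewrite can expose matches for the next; the constant bounds the work rather than iterating to a true fixed point. A generic sketch of the same pattern, where apply_once is hypothetical and would report whether anything changed:

#include <functional>

static void run_to_quiescence(const std::function<bool()>& apply_once,
                              int max_iters = 8)
{
    for(int i = 0; i < max_iters; i++)
    {
        // each round may create new opportunities for the next
        if(not apply_once())
            break; // early exit is an extension; the pass above always runs all rounds
    }
}
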
......
......@@ -53,7 +53,7 @@ struct match_find_quantizable_ops
match::arg(1)(dequantizelinear_op("x2", "scale2")));
}
void apply(module& m, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto qop = r.result;
auto q1 = r.instructions["x1"];
......