Commit 11e155c2 authored by Paul

Merge

parents 8a9c5bce aa7ff911
@@ -12,9 +12,9 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-void rewrite_pooling::apply(module& prog) const
+void rewrite_pooling::apply(module& m) const
 {
-    for(auto ins : iterator_for(prog))
+    for(auto ins : iterator_for(m))
     {
         if(ins->name() != "pooling")
             continue;
@@ -33,26 +33,25 @@ void rewrite_pooling::apply(module& prog) const
             continue;
         std::int64_t n = s.lens()[0];
         std::int64_t c = s.lens()[1];
-        auto reshape = prog.insert_instruction(
+        auto reshape = m.insert_instruction(
             ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
         instruction_ref pooling{};
         // average pooling
-        if(op.mode == "average")
+        if(op.mode == op::pooling_mode::average)
         {
-            pooling =
-                prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
+            pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
         }
         // max pooling
         else
         {
-            pooling = prog.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
+            pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
         }
         std::vector<int64_t> rsp_lens(lens.size(), 1);
         rsp_lens[0] = n;
         rsp_lens[1] = c;
-        prog.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
+        m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
     }
 }
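
The rewrite above replaces a global pooling instruction with reshape -> reduce -> reshape. It relies on a simple identity: when the pooling window covers the whole plane, each (n, c) slice of an NxCxHxW tensor collapses to a single value, so the input can be viewed as n * c rows of h * w elements and reduced along axis 1. A minimal standalone sketch of that identity in plain C++ (independent of the MIGraphX API; the function name and NCHW layout are illustrative):

#include <cassert>
#include <cstddef>
#include <vector>

// Global average pooling over an NCHW tensor, computed the way the rewrite
// does it: view the data as (n * c) rows of (h * w) elements, then take a
// row mean (the graph's reduce_mean over axis 1 of the reshaped tensor).
std::vector<float> global_avg_pool(const std::vector<float>& x,
                                   std::size_t n, std::size_t c,
                                   std::size_t h, std::size_t w)
{
    const std::size_t rows = n * c; // reshape {n, c, h, w} -> {n * c, -1}
    const std::size_t cols = h * w;
    assert(x.size() == rows * cols);
    std::vector<float> out(rows); // reshaped back to {n, c, 1, 1} at the end
    for(std::size_t r = 0; r < rows; ++r)
    {
        float sum = 0.0f;
        for(std::size_t k = 0; k < cols; ++k)
            sum += x[r * cols + k];
        out[r] = sum / static_cast<float>(cols);
    }
    return out;
}

The reduce_max branch is the same computation with a running maximum in place of the mean.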
@@ -30,27 +30,27 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
-void rewrite_rnn::apply(module& prog) const
+void rewrite_rnn::apply(module& m) const
 {
-    for(auto ins : iterator_for(prog))
+    for(auto ins : iterator_for(m))
     {
         if(ins->name() == "rnn")
         {
-            apply_vanilla_rnn(prog, ins);
+            apply_vanilla_rnn(m, ins);
         }
         else if(ins->name() == "gru")
         {
-            apply_gru(prog, ins);
+            apply_gru(m, ins);
         }
         else if(ins->name() == "lstm")
         {
-            apply_lstm(prog, ins);
+            apply_lstm(m, ins);
         }
     }
 }

 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
-void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
+void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
 {
     assert(ins->name() == "rnn");
     // could be 3 to 6 inputs, but the parse_rnn function will
@@ -71,37 +71,37 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
     op::rnn_direction dirct = rnn_op.direction;
     // process sequence length
-    instruction_ref seq_lens = prog.end();
+    instruction_ref seq_lens = m.end();
     if((args.size() >= 5) && args[4]->name() != "undefined")
     {
         seq_lens = args[4];
     }
-    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
+    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
     instruction_ref last_output{};
     if(dirct == op::rnn_direction::bidirectional)
     {
         // input weight matrix
-        auto w_forward = prog.insert_instruction(
+        auto w_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
-        auto w_reverse = prog.insert_instruction(
+        auto w_reverse = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
         // hidden state weight matrix
-        auto r_forward = prog.insert_instruction(
+        auto r_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
-        auto r_reverse = prog.insert_instruction(
+        auto r_reverse = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
         // process bias
-        instruction_ref bias_forward = prog.end();
-        instruction_ref bias_reverse = prog.end();
+        instruction_ref bias_forward = m.end();
+        instruction_ref bias_reverse = m.end();
         if(args.size() >= 4 && args[3]->name() != "undefined")
         {
-            bias_forward = prog.insert_instruction(
+            bias_forward = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
-            bias_reverse = prog.insert_instruction(
+            bias_reverse = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
         }
@@ -111,57 +111,56 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
         instruction_ref ih_reverse{};
         if(args.size() == 6 && args[5]->name() != "undefined")
         {
-            ih_forward = prog.insert_instruction(
+            ih_forward = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
-            ih_reverse = prog.insert_instruction(
+            ih_reverse = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
         }
         else
         {
-            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
-            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
+            ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
+            ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
         }
         auto ret_forward =
             vanilla_rnn_cell(true,
-                             prog,
+                             m,
                              ins,
                              {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                              actv_funcs.at(0));
         if(variable_seq_len)
         {
-            args[0] = prog.insert_instruction(
-                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
+            args[0] =
+                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
         }
         auto ret_reverse =
             vanilla_rnn_cell(false,
-                             prog,
+                             m,
                              ins,
                              {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                              actv_funcs.at(1));
-        auto concat_output = prog.insert_instruction(
+        auto concat_output = m.insert_instruction(
             ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
-        last_output =
-            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
+        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
         // The following logic is to ensure the last instruction rewritten from
         // rnn operator is a concat instruction
         // sequence len is 1
-        if(ret_forward[0] == prog.end())
+        if(ret_forward[0] == m.end())
         {
-            prog.replace_instruction(
+            m.replace_instruction(
                 ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
         }
         else
         {
-            ret_forward[0] = prog.insert_instruction(
+            ret_forward[0] = m.insert_instruction(
                 ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
-            ret_reverse[0] = prog.insert_instruction(
+            ret_reverse[0] = m.insert_instruction(
                 ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
-            prog.replace_instruction(
+            m.replace_instruction(
                 ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
         }
     }
@@ -175,7 +174,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
         auto r = args[2];
         // process bias and initial hidden state
-        instruction_ref bias = prog.end();
+        instruction_ref bias = m.end();
         if(args.size() >= 4 && args[3]->name() != "undefined")
         {
             bias = args[3];
@@ -189,43 +188,42 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
         }
         else
         {
-            ih = prog.add_literal(migraphx::literal{ih_shape, data});
+            ih = m.add_literal(migraphx::literal{ih_shape, data});
         }
         if(!is_forward and variable_seq_len)
         {
-            args[0] = prog.insert_instruction(
-                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
+            args[0] =
+                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
         }
         auto ret = vanilla_rnn_cell(
-            is_forward, prog, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
-        last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
+            is_forward, m, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
+        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
         // following logic is to ensure the last instruction is a
         // concat instruction
         // sequence len is 1
-        if(ret[0] == prog.end())
+        if(ret[0] == m.end())
         {
-            prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
+            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
         }
         else
         {
             auto concat_arg0 = is_forward ? ret[0] : ret[1];
             auto concat_arg1 = is_forward ? ret[1] : ret[0];
-            prog.replace_instruction(
-                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
+            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
         }
     }
     // in case of all sequences are of the same lengths and shorter than the
     // max sequence length, need to pad 0's at the end for output hidden states
-    ins = pad_hidden_states(prog, args[0], seq_lens, ins);
-    replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
+    ins = pad_hidden_states(m, args[0], seq_lens, ins);
+    replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
 }

 std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
-                                                           module& prog,
+                                                           module& m,
                                                            instruction_ref ins,
                                                            std::vector<instruction_ref> inputs,
                                                            operation& actv_func) const
@@ -240,60 +238,60 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
     // squeeze and transpose w
     std::vector<int64_t> perm{1, 0};
-    auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
-    auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
+    auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
+    auto tran_sw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
     // squeeze and transpose r
-    auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
-    auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
+    auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
+    auto tran_sr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
     // initial hidden state
-    auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
+    auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
     auto sih_lens = sih->get_shape().lens();
     // bias
     instruction_ref bb{};
-    if(bias != prog.end())
+    if(bias != m.end())
     {
         long hs = static_cast<long>(r->get_shape().lens()[2]);
-        auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
-        auto wb = prog.insert_instruction(
+        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
+        auto wb = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias);
-        auto rb = prog.insert_instruction(
+        auto rb = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
-        auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
-        bb = prog.insert_instruction(
+        auto wrb = m.insert_instruction(ins, make_op("add"), wb, rb);
+        bb = m.insert_instruction(
             ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
     }
-    instruction_ref hidden_out = prog.end();
+    instruction_ref hidden_out = m.end();
     instruction_ref last_out{};
-    last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
-    long seq_len = get_seq_len(prog, seq, seq_lens);
+    last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
+    long seq_len = get_seq_len(m, seq, seq_lens);
     for(long i = 0; i < seq_len; i++)
     {
         long seq_index = is_forward ? i : (seq_len - 1 - i);
-        auto xt = prog.insert_instruction(
+        auto xt = m.insert_instruction(
             ins,
             make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
             seq);
-        auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
-        xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
-        auto xt_wi = prog.insert_instruction(ins, make_op("dot"), xt, tran_sw);
-        auto ht_ri = prog.insert_instruction(ins, make_op("dot"), sih, tran_sr);
-        if(bias != prog.end())
+        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
+        xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
+        auto xt_wi = m.insert_instruction(ins, make_op("dot"), xt, tran_sw);
+        auto ht_ri = m.insert_instruction(ins, make_op("dot"), sih, tran_sr);
+        if(bias != m.end())
         {
-            xt_wi = prog.insert_instruction(ins, make_op("add"), xt_wi, bb);
+            xt_wi = m.insert_instruction(ins, make_op("add"), xt_wi, bb);
         }
-        auto xt_ht = prog.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);
+        auto xt_ht = m.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);
         // apply activation function
-        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
+        auto ht = m.insert_instruction(ins, actv_func, xt_ht);
         sih = ht;
         // add the dimensions of sequence length (axis 0 for sequence length,
         // axis 1 for num_directions
-        last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
+        last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
         // concatenation for the last last_out is performed in the apply()
         // function to ensure the last instruction is concat, then we have
@@ -304,14 +302,14 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
         {
             hidden_out = (seq_index == 0)
                              ? last_out
-                             : prog.insert_instruction(
+                             : m.insert_instruction(
                                    ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out);
         }
         else
         {
             hidden_out = (seq_index == seq_len - 1)
                              ? last_out
-                             : prog.insert_instruction(
+                             : m.insert_instruction(
                                    ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out);
         }
     }
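
The loop above unrolls the recurrence into graph ops; each iteration builds one step of Ht = f(Xt*(W^T) + Ht-1*(R^T) + Wb + Rb), carrying the new state forward through sih and collecting per-step outputs in hidden_out. A standalone sketch of a single step for one batch row (plain C++; the row-major layout and the tanh default are illustrative assumptions, since the real activation comes from actv_func):

#include <cmath>
#include <cstddef>
#include <vector>

// One vanilla RNN step: Ht = f(Xt*W^T + Ht-1*R^T + Wb + Rb).
// w is hs x is, r is hs x hs, both row-major; bias holds Wb + Rb.
std::vector<float> rnn_step(const std::vector<float>& xt,  // size is
                            const std::vector<float>& ht1, // size hs
                            const std::vector<float>& w,
                            const std::vector<float>& r,
                            const std::vector<float>& bias,
                            std::size_t is, std::size_t hs)
{
    std::vector<float> ht(hs);
    for(std::size_t j = 0; j < hs; ++j)
    {
        float acc = bias[j];
        for(std::size_t k = 0; k < is; ++k)
            acc += xt[k] * w[j * is + k]; // Xt * W^T
        for(std::size_t k = 0; k < hs; ++k)
            acc += ht1[k] * r[j * hs + k]; // Ht-1 * R^T
        ht[j] = std::tanh(acc); // default activation; actv_func in the pass
    }
    return ht;
}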
@@ -358,7 +356,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
 }

 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
-void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
+void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
 {
     assert(ins->name() == "gru");
     const auto actv_funcs = gru_actv_funcs(ins);
@@ -379,37 +377,37 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
     op::rnn_direction dirct = gru_op.direction;
     // process sequence length
-    instruction_ref seq_lens = prog.end();
+    instruction_ref seq_lens = m.end();
     if((args.size() >= 5) && args[4]->name() != "undefined")
     {
         seq_lens = args[4];
     }
-    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
+    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
     instruction_ref last_output{};
     if(dirct == op::rnn_direction::bidirectional)
     {
         // w weight matrix
-        auto w_forward = prog.insert_instruction(
+        auto w_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
-        auto w_reverse = prog.insert_instruction(
+        auto w_reverse = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
         // r weight matrix
-        auto r_forward = prog.insert_instruction(
+        auto r_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
-        auto r_reverse = prog.insert_instruction(
+        auto r_reverse = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
         // bias
-        instruction_ref bias_forward = prog.end();
-        instruction_ref bias_reverse = prog.end();
+        instruction_ref bias_forward = m.end();
+        instruction_ref bias_reverse = m.end();
         if(args.size() >= 4 && args[3]->name() != "undefined")
         {
-            bias_forward = prog.insert_instruction(
+            bias_forward = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
-            bias_reverse = prog.insert_instruction(
+            bias_reverse = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
         }
@@ -418,20 +416,20 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
         instruction_ref ih_reverse{};
         if(args.size() == 6 && args[5]->name() != "undefined")
         {
-            ih_forward = prog.insert_instruction(
+            ih_forward = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
-            ih_reverse = prog.insert_instruction(
+            ih_reverse = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
         }
         else
         {
-            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
-            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
+            ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
+            ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
         }
         auto ret_forward =
             gru_cell(true,
-                     prog,
+                     m,
                      ins,
                      {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                      gru_op.linear_before_reset,
@@ -440,38 +438,37 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
         if(variable_seq_len)
         {
-            args[0] = prog.insert_instruction(
-                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
+            args[0] =
+                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
         }
         auto ret_reverse =
             gru_cell(false,
-                     prog,
+                     m,
                      ins,
                      {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                      gru_op.linear_before_reset,
                      actv_funcs.at(2),
                      actv_funcs.at(3));
-        auto concat_output = prog.insert_instruction(
+        auto concat_output = m.insert_instruction(
             ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
-        last_output =
-            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
+        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
         // The following logic is to ensure the last instruction rewritten
         // from gru operator is a concat
-        if(ret_forward[0] == prog.end())
+        if(ret_forward[0] == m.end())
         {
-            prog.replace_instruction(
+            m.replace_instruction(
                 ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
         }
         else
         {
-            ret_forward[0] = prog.insert_instruction(
+            ret_forward[0] = m.insert_instruction(
                 ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
-            ret_reverse[0] = prog.insert_instruction(
+            ret_reverse[0] = m.insert_instruction(
                 ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
-            prog.replace_instruction(
+            m.replace_instruction(
                 ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
         }
     }
@@ -483,7 +480,7 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
         auto r = args[2];
         // bias
-        instruction_ref bias = prog.end();
+        instruction_ref bias = m.end();
         if(args.size() >= 4 && args[3]->name() != "undefined")
         {
             bias = args[3];
@@ -497,47 +494,46 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
         }
         else
         {
-            ih = prog.add_literal(migraphx::literal{ih_shape, data});
+            ih = m.add_literal(migraphx::literal{ih_shape, data});
         }
         if(!is_forward and variable_seq_len)
         {
-            args[0] = prog.insert_instruction(
-                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
+            args[0] =
+                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }
         auto ret = gru_cell(is_forward,
-                            prog,
+                            m,
                             ins,
                             {args[0], w, r, bias, seq_lens, ih},
                             gru_op.linear_before_reset,
                             actv_funcs.at(0),
                             actv_funcs.at(1));
-        last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
-        if(ret[0] == prog.end())
+        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
+        if(ret[0] == m.end())
         {
-            prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
+            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
         }
         else
         {
             auto concat_arg0 = is_forward ? ret[0] : ret[1];
             auto concat_arg1 = is_forward ? ret[1] : ret[0];
-            prog.replace_instruction(
-                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
+            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
         }
     }
     // in case of all sequences are of the same lengths and shorter than the
     // max sequence length, need to pad 0's at the end for output hidden states
-    ins = pad_hidden_states(prog, args[0], seq_lens, ins);
-    replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
+    ins = pad_hidden_states(m, args[0], seq_lens, ins);
+    replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
 }

 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
 std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
-                                                   module& prog,
+                                                   module& m,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    int linear_before_reset,
@@ -552,7 +548,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
     auto seq_lens = inputs.at(4);
     auto ih = inputs.at(5);
-    instruction_ref hidden_states = prog.end();
+    instruction_ref hidden_states = m.end();
     instruction_ref last_output{};
     migraphx::shape seq_shape = seq->get_shape();
     migraphx::shape r_shape = r->get_shape();
@@ -560,127 +556,127 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
     migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
     std::vector<float> data(ss.elements(), 1.0f);
-    auto l1 = prog.add_literal(migraphx::literal{ss, data});
+    auto l1 = m.add_literal(migraphx::literal{ss, data});
     // w matrix squeeze to 2-dim and do a transpose
     std::vector<int64_t> perm{1, 0};
-    auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
-    auto tw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
+    auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
+    auto tw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
     // r slide to two part, zr and h
-    auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
-    auto rzr = prog.insert_instruction(
+    auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
+    auto rzr = m.insert_instruction(
         ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
-    auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
-    auto rh = prog.insert_instruction(
+    auto trzr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
+    auto rh = m.insert_instruction(
         ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
-    auto trh = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
+    auto trh = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
     // initial states
-    auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
+    auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
     size_t bs = ih->get_shape().lens()[1];
     // bias
     instruction_ref bwb{};
     instruction_ref brb_zr{};
     instruction_ref brb_h{};
-    if(bias != prog.end())
+    if(bias != m.end())
     {
-        auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
-        auto wb = prog.insert_instruction(
+        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
+        auto wb = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
-        bwb = prog.insert_instruction(
+        bwb = m.insert_instruction(
             ins,
             make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
             wb);
-        auto rb_zr = prog.insert_instruction(
+        auto rb_zr = m.insert_instruction(
             ins,
             make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}),
             sbias);
-        auto rb_h = prog.insert_instruction(
+        auto rb_h = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}),
            sbias);
-        brb_zr = prog.insert_instruction(
+        brb_zr = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
            rb_zr);
-        brb_h = prog.insert_instruction(
+        brb_h = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
            rb_h);
     }
-    long seq_len = get_seq_len(prog, seq, seq_lens);
+    long seq_len = get_seq_len(m, seq, seq_lens);
     for(long i = 0; i < seq_len; i++)
     {
         long seq_index = is_forward ? i : (seq_len - 1 - i);
-        auto xt = prog.insert_instruction(
+        auto xt = m.insert_instruction(
             ins,
             make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
             seq);
-        auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
-        xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
-        auto xt_w = prog.insert_instruction(ins, make_op("dot"), xt, tw);
-        auto ih1_rzr = prog.insert_instruction(ins, make_op("dot"), sih, trzr);
-        if(bias != prog.end())
+        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
+        xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
+        auto xt_w = m.insert_instruction(ins, make_op("dot"), xt, tw);
+        auto ih1_rzr = m.insert_instruction(ins, make_op("dot"), sih, trzr);
+        if(bias != m.end())
         {
-            xt_w = prog.insert_instruction(ins, make_op("add"), xt_w, bwb);
-            ih1_rzr = prog.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
+            xt_w = m.insert_instruction(ins, make_op("add"), xt_w, bwb);
+            ih1_rzr = m.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
         }
-        auto xw_z = prog.insert_instruction(
+        auto xw_z = m.insert_instruction(
             ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w);
-        auto xw_r = prog.insert_instruction(
+        auto xw_r = m.insert_instruction(
             ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w);
-        auto xw_h = prog.insert_instruction(
+        auto xw_h = m.insert_instruction(
             ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w);
-        auto hr_z = prog.insert_instruction(
+        auto hr_z = m.insert_instruction(
             ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr);
-        auto hr_r = prog.insert_instruction(
+        auto hr_r = m.insert_instruction(
             ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr);
-        auto xw_hr_z = prog.insert_instruction(ins, make_op("add"), xw_z, hr_z);
-        auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z);
-        auto xw_hr_r = prog.insert_instruction(ins, make_op("add"), xw_r, hr_r);
-        auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r);
+        auto xw_hr_z = m.insert_instruction(ins, make_op("add"), xw_z, hr_z);
+        auto zt = m.insert_instruction(ins, actv_func1, xw_hr_z);
+        auto xw_hr_r = m.insert_instruction(ins, make_op("add"), xw_r, hr_r);
+        auto rt = m.insert_instruction(ins, actv_func1, xw_hr_r);
         instruction_ref hr_h{};
         if(linear_before_reset == 0)
         {
             // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
-            auto rt_ht1 = prog.insert_instruction(ins, make_op("mul"), rt, sih);
-            hr_h = prog.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
-            if(bias != prog.end())
+            auto rt_ht1 = m.insert_instruction(ins, make_op("mul"), rt, sih);
+            hr_h = m.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
+            if(bias != m.end())
             {
-                hr_h = prog.insert_instruction(ins, make_op("add"), hr_h, brb_h);
+                hr_h = m.insert_instruction(ins, make_op("add"), hr_h, brb_h);
             }
         }
         else
         {
             // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
-            auto ht1_rh = prog.insert_instruction(ins, make_op("dot"), sih, trh);
-            if(bias != prog.end())
+            auto ht1_rh = m.insert_instruction(ins, make_op("dot"), sih, trh);
+            if(bias != m.end())
             {
-                ht1_rh = prog.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
+                ht1_rh = m.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
             }
-            hr_h = prog.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
+            hr_h = m.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
         }
-        auto xw_hr_h = prog.insert_instruction(ins, make_op("add"), xw_h, hr_h);
-        auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h);
+        auto xw_hr_h = m.insert_instruction(ins, make_op("add"), xw_h, hr_h);
+        auto ht = m.insert_instruction(ins, actv_func2, xw_hr_h);
         // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
-        auto one_minus_zt = prog.insert_instruction(ins, make_op("sub"), l1, zt);
-        auto one_minus_zt_ht = prog.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
-        auto zt_ht1 = prog.insert_instruction(ins, make_op("mul"), zt, sih);
-        sih = prog.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
-        last_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
+        auto one_minus_zt = m.insert_instruction(ins, make_op("sub"), l1, zt);
+        auto one_minus_zt_ht = m.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
+        auto zt_ht1 = m.insert_instruction(ins, make_op("mul"), zt, sih);
+        sih = m.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
+        last_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
         if(i < seq_len - 1)
         {
@@ -689,7 +685,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
             hidden_states =
                 (seq_index == 0)
                     ? last_output
-                    : prog.insert_instruction(
+                    : m.insert_instruction(
                           ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output);
         }
         else
@@ -697,7 +693,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
             hidden_states =
                 (seq_index == seq_len - 1)
                     ? last_output
-                    : prog.insert_instruction(
+                    : m.insert_instruction(
                           ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states);
         }
     }
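
Per time step, the loop above assembles the ONNX GRU equations: update and reset gates zt and rt from the [z | r] slices, one of the two ht formulas selected by linear_before_reset (reset applied before or after the recurrent matmul), and finally Ht = (1 - zt) (.) ht + zt (.) Ht-1. A standalone sketch of one step for a single batch row (plain C++; the row-major layout, slice order, and sigmoid/tanh choices are illustrative assumptions since the real activations come from actv_func1/actv_func2, and the Rbz/Rbr biases are folded into xw for brevity):

#include <cmath>
#include <cstddef>
#include <vector>

// One GRU step, Ht = (1 - zt) (.) ht + zt (.) Ht-1, for a single batch row.
// xw holds Xt*W^T (+ Wb) as three hs-wide slices [z | r | h]; rzr is the
// (2*hs) x hs matrix [Rz; Rr], rh the hs x hs matrix Rh (both row-major),
// and rbh is the Rbh bias.
std::vector<float> gru_step(const std::vector<float>& xw,
                            const std::vector<float>& ht1,
                            const std::vector<float>& rzr,
                            const std::vector<float>& rh,
                            const std::vector<float>& rbh,
                            std::size_t hs, bool linear_before_reset)
{
    auto sigm = [](float v) { return 1.0f / (1.0f + std::exp(-v)); };
    auto row_dot = [hs](const std::vector<float>& a, std::size_t row,
                        const std::vector<float>& v) {
        float s = 0.0f;
        for(std::size_t k = 0; k < hs; ++k)
            s += a[row * hs + k] * v[k];
        return s;
    };
    std::vector<float> zt(hs), rt(hs), rt_ht1(hs), out(hs);
    for(std::size_t j = 0; j < hs; ++j)
    {
        zt[j] = sigm(xw[j] + row_dot(rzr, j, ht1));           // update gate
        rt[j] = sigm(xw[hs + j] + row_dot(rzr, hs + j, ht1)); // reset gate
        rt_ht1[j] = rt[j] * ht1[j];                           // rt (.) Ht-1
    }
    for(std::size_t j = 0; j < hs; ++j)
    {
        // reset applied before the matmul (== 0) vs. after it (!= 0)
        float hr_h = linear_before_reset
                         ? rt[j] * (row_dot(rh, j, ht1) + rbh[j])
                         : row_dot(rh, j, rt_ht1) + rbh[j];
        float ht = std::tanh(xw[2 * hs + j] + hr_h);
        out[j] = (1.0f - zt[j]) * ht + zt[j] * ht1[j]; // Ht
    }
    return out;
}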
@@ -748,7 +744,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
 // for lstm operators
 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
-void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
+void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
 {
     assert(ins->name() == "lstm");
     auto args = ins->inputs();
@@ -767,13 +763,13 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
     op::rnn_direction dirct = lstm_op.direction;
     // process sequence length
-    instruction_ref seq_lens = prog.end();
+    instruction_ref seq_lens = m.end();
     if((args.size() >= 5) && args[4]->name() != "undefined")
     {
         seq_lens = args[4];
     }
-    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
+    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
     instruction_ref last_hs_output{};
     instruction_ref last_cell_output{};
@@ -783,25 +779,25 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
     {
         // input weight matrix
         // input weight matrix
-        auto w_forward = prog.insert_instruction(
+        auto w_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
-        auto w_reverse = prog.insert_instruction(
+        auto w_reverse = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
         // hidden state weight matrix
-        auto r_forward = prog.insert_instruction(
+        auto r_forward = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
-        auto r_reverse = prog.insert_instruction(
+        auto r_reverse = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
         // process bias
-        instruction_ref bias_forward = prog.end();
-        instruction_ref bias_reverse = prog.end();
+        instruction_ref bias_forward = m.end();
+        instruction_ref bias_reverse = m.end();
         if(args.size() >= 4 && args[3]->name() != "undefined")
         {
-            bias_forward = prog.insert_instruction(
+            bias_forward = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
-            bias_reverse = prog.insert_instruction(
+            bias_reverse = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
         }
@@ -810,15 +806,15 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
         instruction_ref ih_reverse{};
         if(args.size() >= 6 && args[5]->name() != "undefined")
         {
-            ih_forward = prog.insert_instruction(
+            ih_forward = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
-            ih_reverse = prog.insert_instruction(
+            ih_reverse = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
         }
         else
         {
-            ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
-            ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
+            ih_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
+            ih_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
         }
         // process initial cell value
@@ -826,30 +822,30 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
         instruction_ref ic_reverse{};
         if(args.size() >= 7 && args[6]->name() != "undefined")
         {
-            ic_forward = prog.insert_instruction(
+            ic_forward = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
-            ic_reverse = prog.insert_instruction(
+            ic_reverse = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]);
         }
         else
         {
-            ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
-            ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
+            ic_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
+            ic_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
         }
         // process weight of the peephole
-        instruction_ref pph_forward = prog.end();
-        instruction_ref pph_reverse = prog.end();
+        instruction_ref pph_forward = m.end();
+        instruction_ref pph_reverse = m.end();
         if(args.size() == 8 && args[7]->name() != "undefined")
         {
-            pph_forward = prog.insert_instruction(
+            pph_forward = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
-            pph_reverse = prog.insert_instruction(
+            pph_reverse = m.insert_instruction(
                 ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]);
         }
         auto ret_forward = lstm_cell(true,
-                                     prog,
+                                     m,
                                      ins,
                                      {args[0],
                                       w_forward,
@@ -865,11 +861,11 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
         if(variable_seq_len)
         {
-            args[0] = prog.insert_instruction(
-                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
+            args[0] =
+                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
         }
         auto ret_reverse = lstm_cell(false,
-                                     prog,
+                                     m,
                                      ins,
                                      {args[0],
                                       w_reverse,
@@ -883,36 +879,36 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
                                      actv_funcs.at(4),
                                      actv_funcs.at(5));
-        auto concat_hs_output = prog.insert_instruction(
+        auto concat_hs_output = m.insert_instruction(
             ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
-        auto concat_cell_output = prog.insert_instruction(
+        auto concat_cell_output = m.insert_instruction(
             ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
         last_hs_output =
-            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
+            m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
         last_cell_output =
-            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);
+            m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);
         // the following logic is to ensure the last instruction is a concat
-        if(ret_forward[0] == prog.end())
+        if(ret_forward[0] == m.end())
         {
             cell_outputs = concat_cell_output;
         }
         else
         {
-            ret_forward[1] = prog.insert_instruction(
+            ret_forward[1] = m.insert_instruction(
                 ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
-            ret_reverse[1] = prog.insert_instruction(
+            ret_reverse[1] = m.insert_instruction(
                 ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
-            ret_forward[3] = prog.insert_instruction(
+            ret_forward[3] = m.insert_instruction(
                 ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]);
-            ret_reverse[3] = prog.insert_instruction(
+            ret_reverse[3] = m.insert_instruction(
                 ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]);
-            cell_outputs = prog.insert_instruction(
+            cell_outputs = m.insert_instruction(
                 ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
         }
-        hidden_state = prog.replace_instruction(
+        hidden_state = m.replace_instruction(
             ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]});
     }
     else
@@ -923,7 +919,7 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
         auto r = args[2];
         // bias
-        instruction_ref bias = prog.end();
+        instruction_ref bias = m.end();
         if(args.size() >= 4 && args[3]->name() != "undefined")
         {
             bias = args[3];
@@ -937,7 +933,7 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
         }
         else
         {
-            ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
+            ih = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
         }
         // initial cell value
@@ -948,11 +944,11 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
         }
         else
         {
-            ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
+            ic = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
         }
         // process weight of the peephole
-        instruction_ref pph = prog.end();
+        instruction_ref pph = m.end();
         if(args.size() == 8 && args[7]->name() != "undefined")
         {
             pph = args[7];
@@ -960,54 +956,53 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
         if(!is_forward and variable_seq_len)
         {
-            args[0] = prog.insert_instruction(
-                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
+            args[0] =
+                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
         }
         auto ret = lstm_cell(is_forward,
-                             prog,
+                             m,
                              ins,
                              {args[0], w, r, bias, seq_lens, ih, ic, pph},
                              actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));
-        last_hs_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
-        last_cell_output =
-            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);
-        if(ret[0] == prog.end())
+        last_hs_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
+        last_cell_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);
+        if(ret[0] == m.end())
         {
             cell_outputs = ret[3];
-            hidden_state = prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
+            hidden_state = m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
         }
         else
         {
             auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
             auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
-            cell_outputs = prog.insert_instruction(
+            cell_outputs = m.insert_instruction(
                 ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
             auto concat_arg0 = is_forward ? ret[0] : ret[1];
             auto concat_arg1 = is_forward ? ret[1] : ret[0];
-            hidden_state = prog.replace_instruction(
+            hidden_state = m.replace_instruction(
                 ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
         }
     }
     // in case of all sequences are of the same lengths and shorter than the
     // max sequence length, need to pad 0's at the end for output hidden states
-    hidden_state = pad_hidden_states(prog, args[0], seq_lens, hidden_state);
+    hidden_state = pad_hidden_states(m, args[0], seq_lens, hidden_state);
     // replace last hidden states with corresponding instructions
-    ins = replace_last_hs_output(prog, hidden_state, seq_lens, last_hs_output, dirct);
+    ins = replace_last_hs_output(m, hidden_state, seq_lens, last_hs_output, dirct);
     // replace last cell outputs with corresponding instructions
-    replace_last_cell_output(prog, ins, seq_lens, cell_outputs, last_cell_output, dirct);
+    replace_last_cell_output(m, ins, seq_lens, cell_outputs, last_cell_output, dirct);
 }

 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
 std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
-                                                    module& prog,
+                                                    module& m,
                                                     instruction_ref ins,
                                                     std::vector<instruction_ref> inputs,
                                                     const operation& actv_func1,
@@ -1025,8 +1020,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     auto ic = inputs.at(6);
     auto pph = inputs.at(7);
-    instruction_ref hidden_states = prog.end();
-    instruction_ref cell_outputs = prog.end();
+    instruction_ref hidden_states = m.end();
+    instruction_ref cell_outputs = m.end();
     instruction_ref last_hs_output{};
     instruction_ref last_cell_output{};
@@ -1037,35 +1032,35 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
     std::vector<int64_t> perm{1, 0};
     // w matrix, squeeze and transpose
-    auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
-    auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
+    auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
+    auto tsw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
     // r matrix, squeeze and transpose
-    auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
-    auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
+    auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
+    auto tsr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
     // initial hidden state
-    auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
+    auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
     // initial cell state
-    auto sic = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
+    auto sic = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
     auto ic_lens = sic->get_shape().lens();
     // bias
     instruction_ref wrb{};
-    if(bias != prog.end())
+    if(bias != m.end())
     {
-        auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
-        auto ub_wb = prog.insert_instruction(
+        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
+        auto ub_wb = m.insert_instruction(
             ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias);
-        auto ub_rb = prog.insert_instruction(
+        auto ub_rb = m.insert_instruction(
             ins,
             make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}),
             sbias);
-        auto ub_wrb = prog.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);
-        wrb = prog.insert_instruction(
+        auto ub_wrb = m.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);
+        wrb = m.insert_instruction(
             ins,
             make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
             ub_wrb);
...@@ -1075,92 +1070,91 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1075,92 +1070,91 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref pphi_brcst{}; instruction_ref pphi_brcst{};
instruction_ref ppho_brcst{}; instruction_ref ppho_brcst{};
instruction_ref pphf_brcst{}; instruction_ref pphf_brcst{};
if(pph != prog.end()) if(pph != m.end())
{ {
auto spph = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph); auto spph = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
auto pphi = prog.insert_instruction( auto pphi = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph); ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
pphi_brcst = prog.insert_instruction( pphi_brcst = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi); ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);
auto ppho = prog.insert_instruction( auto ppho = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph); ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
ppho_brcst = prog.insert_instruction( ppho_brcst = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho); ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);
auto pphf = prog.insert_instruction( auto pphf = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph); ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
pphf_brcst = prog.insert_instruction( pphf_brcst = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf); ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
} }
long seq_len = get_seq_len(prog, seq, seq_lens); long seq_len = get_seq_len(m, seq, seq_lens);
for(long i = 0; i < seq_len; ++i) for(long i = 0; i < seq_len; ++i)
{ {
long seq_index = is_forward ? i : (seq_len - 1 - i); long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction( auto xt = m.insert_instruction(
ins, ins,
make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}), make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq); seq);
auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt); auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt); xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_tsw = prog.insert_instruction(ins, make_op("dot"), xt, tsw); auto xt_tsw = m.insert_instruction(ins, make_op("dot"), xt, tsw);
auto sih_tsr = prog.insert_instruction(ins, make_op("dot"), sih, tsr); auto sih_tsr = m.insert_instruction(ins, make_op("dot"), sih, tsr);
auto xt_sih = prog.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr); auto xt_sih = m.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
if(bias != prog.end()) if(bias != m.end())
{ {
xt_sih = prog.insert_instruction(ins, make_op("add"), xt_sih, wrb); xt_sih = m.insert_instruction(ins, make_op("add"), xt_sih, wrb);
} }
auto it_before_actv = prog.insert_instruction( auto it_before_actv = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih); ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih);
auto ot_before_actv = prog.insert_instruction( auto ot_before_actv = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih); ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih);
auto ft_before_actv = prog.insert_instruction( auto ft_before_actv = m.insert_instruction(
ins, ins,
make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}),
xt_sih); xt_sih);
auto ct_before_actv = prog.insert_instruction( auto ct_before_actv = m.insert_instruction(
ins, ins,
make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}), make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}),
xt_sih); xt_sih);
if(pph != prog.end()) if(pph != m.end())
{ {
auto pphi_ct = prog.insert_instruction(ins, make_op("mul"), pphi_brcst, sic); auto pphi_ct = m.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct); it_before_actv = m.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);
auto pphf_ct = prog.insert_instruction(ins, make_op("mul"), pphf_brcst, sic); auto pphf_ct = m.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct); ft_before_actv = m.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
} }
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv); auto it = m.insert_instruction(ins, actv_func1, it_before_actv);
auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv); auto ft = m.insert_instruction(ins, actv_func1, ft_before_actv);
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv); auto ct = m.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct // equation Ct = ft (.) Ct-1 + it (.) ct
auto ft_cell = prog.insert_instruction(ins, make_op("mul"), ft, sic); auto ft_cell = m.insert_instruction(ins, make_op("mul"), ft, sic);
auto it_ct = prog.insert_instruction(ins, make_op("mul"), it, ct); auto it_ct = m.insert_instruction(ins, make_op("mul"), it, ct);
auto cellt = prog.insert_instruction(ins, make_op("add"), ft_cell, it_ct); auto cellt = m.insert_instruction(ins, make_op("add"), ft_cell, it_ct);
if(pph != prog.end()) if(pph != m.end())
{ {
auto ppho_cellt = prog.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt); auto ppho_cellt = m.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
ot_before_actv = ot_before_actv = m.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
prog.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
} }
auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv); auto ot = m.insert_instruction(ins, actv_func1, ot_before_actv);
// Ht = ot (.) h(Ct) // Ht = ot (.) h(Ct)
auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt); auto h_cellt = m.insert_instruction(ins, actv_func3, cellt);
auto ht = prog.insert_instruction(ins, make_op("mul"), ot, h_cellt); auto ht = m.insert_instruction(ins, make_op("mul"), ot, h_cellt);
sic = cellt; sic = cellt;
sih = ht; sih = ht;
last_hs_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht); last_hs_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
last_cell_output = last_cell_output =
prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt); m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);
if(i < seq_len - 1) if(i < seq_len - 1)
{ {
...@@ -1173,12 +1167,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1173,12 +1167,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
{ {
auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output; auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states; auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
hidden_states = prog.insert_instruction( hidden_states = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1); ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1);
auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output; auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs; auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
cell_outputs = prog.insert_instruction( cell_outputs = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1); ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
} }
} }
...@@ -1266,10 +1260,10 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const ...@@ -1266,10 +1260,10 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
} }
} }
bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens) const
{ {
bool is_var_lens = false; bool is_var_lens = false;
if(seq_lens != prog.end()) if(seq_lens != m.end())
{ {
if(seq_lens->can_eval()) if(seq_lens->can_eval())
{ {
...@@ -1296,12 +1290,12 @@ bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_l ...@@ -1296,12 +1290,12 @@ bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_l
} }
std::size_t std::size_t
rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const
{ {
bool is_var_lens = is_variable_seq_lens(prog, seq_lens); bool is_var_lens = is_variable_seq_lens(m, seq_lens);
auto input_shape = input->get_shape(); auto input_shape = input->get_shape();
auto length = input_shape.lens()[0]; auto length = input_shape.lens()[0];
if(!is_var_lens and seq_lens != prog.end()) if(!is_var_lens and seq_lens != m.end())
{ {
auto arg_len = seq_lens->eval(); auto arg_len = seq_lens->eval();
std::vector<std::size_t> vec_lens; std::vector<std::size_t> vec_lens;
...@@ -1312,33 +1306,33 @@ rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ ...@@ -1312,33 +1306,33 @@ rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_
return length; return length;
} }
instruction_ref rewrite_rnn::replace_last_hs_output(module& prog, instruction_ref rewrite_rnn::replace_last_hs_output(module& m,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref last_hs_output, instruction_ref last_hs_output,
op::rnn_direction dirct) const op::rnn_direction dirct) const
{ {
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens); bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
instruction_ref result_ins{}; instruction_ref result_ins{};
if(variable_seq_len) if(variable_seq_len)
{ {
result_ins = prog.insert_instruction( result_ins =
std::next(ins), m.insert_instruction(std::next(ins),
make_op("rnn_var_sl_shift_output", make_op("rnn_var_sl_shift_output",
{{"output_name", "hidden_states"}, {"direction", dirct}}), {{"output_name", "hidden_states"}, {"direction", dirct}}),
ins, ins,
seq_lens); seq_lens);
prog.replace_instruction(ins, result_ins); m.replace_instruction(ins, result_ins);
auto hs_outputs = find_all(result_ins->outputs(), auto hs_outputs = find_all(result_ins->outputs(),
[&](auto i) { return i->name() == "rnn_last_hs_output"; }); [&](auto i) { return i->name() == "rnn_last_hs_output"; });
for(auto& hs_out : hs_outputs) for(auto& hs_out : hs_outputs)
{ {
auto inputs = hs_out->inputs(); auto inputs = hs_out->inputs();
prog.replace_instruction(hs_out, m.replace_instruction(hs_out,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}), make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
inputs.front(), inputs.front(),
seq_lens); seq_lens);
} }
} }
else else
...@@ -1348,7 +1342,7 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog, ...@@ -1348,7 +1342,7 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
for(auto& hs_out : hs_outputs) for(auto& hs_out : hs_outputs)
{ {
prog.replace_instruction(hs_out, last_hs_output); m.replace_instruction(hs_out, last_hs_output);
} }
result_ins = ins; result_ins = ins;
...@@ -1357,14 +1351,14 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog, ...@@ -1357,14 +1351,14 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
return result_ins; return result_ins;
} }
void rewrite_rnn::replace_last_cell_output(module& prog, void rewrite_rnn::replace_last_cell_output(module& m,
instruction_ref ins, instruction_ref ins,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref cell_outputs, instruction_ref cell_outputs,
instruction_ref last_cell_output, instruction_ref last_cell_output,
op::rnn_direction dirct) const op::rnn_direction dirct) const
{ {
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens); bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
auto ins_outputs = auto ins_outputs =
find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; }); find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; });
...@@ -1372,7 +1366,7 @@ void rewrite_rnn::replace_last_cell_output(module& prog, ...@@ -1372,7 +1366,7 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
{ {
if(!ins_outputs.empty()) if(!ins_outputs.empty())
{ {
cell_outputs = prog.insert_instruction( cell_outputs = m.insert_instruction(
std::next(ins), std::next(ins),
make_op("rnn_var_sl_shift_output", make_op("rnn_var_sl_shift_output",
{{"output_name", "cell_outputs"}, {"direction", dirct}}), {{"output_name", "cell_outputs"}, {"direction", dirct}}),
...@@ -1382,10 +1376,10 @@ void rewrite_rnn::replace_last_cell_output(module& prog, ...@@ -1382,10 +1376,10 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
for(auto co : ins_outputs) for(auto co : ins_outputs)
{ {
prog.replace_instruction(co, m.replace_instruction(co,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}), make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
cell_outputs, cell_outputs,
seq_lens); seq_lens);
} }
} }
// replace the rnn_last_cell_output with the last_cell_output. The while // replace the rnn_last_cell_output with the last_cell_output. The while
...@@ -1394,18 +1388,18 @@ void rewrite_rnn::replace_last_cell_output(module& prog, ...@@ -1394,18 +1388,18 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
{ {
for(auto co : ins_outputs) for(auto co : ins_outputs)
{ {
prog.replace_instruction(co, last_cell_output); m.replace_instruction(co, last_cell_output);
} }
} }
} }
instruction_ref rewrite_rnn::pad_hidden_states(module& prog, instruction_ref rewrite_rnn::pad_hidden_states(module& m,
instruction_ref seq, instruction_ref seq,
instruction_ref seq_lens, instruction_ref seq_lens,
instruction_ref hs) const instruction_ref hs) const
{ {
auto max_seq_len = seq->get_shape().lens()[0]; auto max_seq_len = seq->get_shape().lens()[0];
auto seq_len = get_seq_len(prog, seq, seq_lens); auto seq_len = get_seq_len(m, seq, seq_lens);
// in the case where all sequences are of the same length and // in the case where all sequences are of the same length and
// shorter than max_seq_len, we need to append the hs outputs // shorter than max_seq_len, we need to append the hs outputs
...@@ -1417,23 +1411,13 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog, ...@@ -1417,23 +1411,13 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
pad_lens[0] = static_cast<std::size_t>(max_seq_len - seq_len); pad_lens[0] = static_cast<std::size_t>(max_seq_len - seq_len);
shape pad_s{s.type(), pad_lens}; shape pad_s{s.type(), pad_lens};
std::vector<float> pad_data(pad_s.elements(), 0.0f); std::vector<float> pad_data(pad_s.elements(), 0.0f);
auto pl = prog.add_literal(pad_s, pad_data.begin(), pad_data.end()); auto pl = m.add_literal(pad_s, pad_data.begin(), pad_data.end());
hs_padded = hs_padded = m.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
prog.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl); m.replace_instruction(hs, hs_padded);
prog.replace_instruction(hs, hs_padded);
} }
return hs_padded; return hs_padded;
} }
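For orientation, the per-timestep arithmetic that lstm_cell unrolls above is the standard LSTM recurrence. A minimal scalar sketch of one step (a self-contained illustration with invented names, not MIGraphX code), assuming sigmoid for the gate activation (actv_func1), tanh for the cell activations (actv_func2/actv_func3), and with the peephole terms omitted:

#include <cmath>

// g_* are the pre-activation gate values, i.e. the slices of
// x_t*W^T + h_{t-1}*R^T (+ bias) that the code above calls
// it/ot/ft/ct_before_actv. c_prev is updated in place.
inline double lstm_step(double g_i, double g_o, double g_f, double g_c, double& c_prev)
{
    auto sigmoid = [](double x) { return 1.0 / (1.0 + std::exp(-x)); };
    double it = sigmoid(g_i);
    double ot = sigmoid(g_o);
    double ft = sigmoid(g_f);
    double ct = std::tanh(g_c);
    c_prev    = ft * c_prev + it * ct; // Ct = ft (.) Ct-1 + it (.) ct
    return ot * std::tanh(c_prev);     // Ht = ot (.) h(Ct)
}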
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -42,7 +42,7 @@ struct stream_info ...@@ -42,7 +42,7 @@ struct stream_info
std::unordered_map<instruction_ref, std::size_t> iweights; std::unordered_map<instruction_ref, std::size_t> iweights;
ins_dep_map mod_implicit_deps; ins_dep_map mod_implicit_deps;
void calc_implicit_deps(const module& p) { mod_implicit_deps = p.calc_implicit_deps(); } void calc_implicit_deps(const module& m) { mod_implicit_deps = m.calc_implicit_deps(); }
void accumulate_weights(instruction_ref last, const schedule_model& model) void accumulate_weights(instruction_ref last, const schedule_model& model)
{ {
...@@ -116,15 +116,15 @@ struct stream_info ...@@ -116,15 +116,15 @@ struct stream_info
} }
}; };
std::size_t assign_streams(module& p, std::size_t n) std::size_t assign_streams(module& m, std::size_t n)
{ {
assert(n > 0); assert(n > 0);
partition critical; partition critical;
std::unordered_map<instruction_ref, std::deque<partition>> partitions; std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size()); partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) { fix([&](auto self, auto ins, auto& part) {
assert(not is_end(ins, p.end())); assert(not is_end(ins, m.end()));
if(not p.has_instruction(ins)) if(not m.has_instruction(ins))
return; return;
if(contains(partitions, ins)) if(contains(partitions, ins))
return; return;
...@@ -151,8 +151,8 @@ struct stream_info ...@@ -151,8 +151,8 @@ struct stream_info
} }
} }
// Sort instructions // Sort instructions
p.move_instruction(ins, p.end()); m.move_instruction(ins, m.end());
})(std::prev(p.end()), critical); })(std::prev(m.end()), critical);
// Set the critical partition to stream 0 // Set the critical partition to stream 0
set_stream(critical, 0); set_stream(critical, 0);
...@@ -197,13 +197,13 @@ struct stream_info ...@@ -197,13 +197,13 @@ struct stream_info
} }
}; };
void sort(module& p, std::size_t) void sort(module& m, std::size_t)
{ {
std::set<weight_ins, compare_weight_ins> children; std::set<weight_ins, compare_weight_ins> children;
std::unordered_map<instruction_ref, std::size_t> visited; std::unordered_map<instruction_ref, std::size_t> visited;
auto last = std::prev(p.end()); auto last = std::prev(m.end());
auto mw = this->weights.at(last); auto mw = this->weights.at(last);
auto nw = mw / (p.size() + 1); auto nw = mw / (m.size() + 1);
auto add_child = [&](auto ins) { auto add_child = [&](auto ins) {
auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1); auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1);
auto w = x * this->iweights.at(ins); auto w = x * this->iweights.at(ins);
...@@ -222,10 +222,10 @@ struct stream_info ...@@ -222,10 +222,10 @@ struct stream_info
// Pop the first element // Pop the first element
auto top = children.begin()->second; auto top = children.begin()->second;
children.erase(children.begin()); children.erase(children.begin());
p.move_instruction(top, p.begin()); m.move_instruction(top, m.begin());
for(auto ins : top->inputs()) for(auto ins : top->inputs())
{ {
if(not p.has_instruction(ins)) if(not m.has_instruction(ins))
continue; continue;
add_child(ins); add_child(ins);
} }
...@@ -234,7 +234,7 @@ struct stream_info ...@@ -234,7 +234,7 @@ struct stream_info
{ {
for(auto ins : mod_implicit_deps.at(top)) for(auto ins : mod_implicit_deps.at(top))
{ {
assert(p.has_instruction(ins)); assert(m.has_instruction(ins));
add_child(ins); add_child(ins);
} }
} }
...@@ -242,12 +242,12 @@ struct stream_info ...@@ -242,12 +242,12 @@ struct stream_info
// move dangling parameters to the front so they are not removed // move dangling parameters to the front so they are not removed
auto ins = std::next(last); auto ins = std::next(last);
while(ins != p.end()) while(ins != m.end())
{ {
auto next = std::next(ins); auto next = std::next(ins);
if(ins->name() == "@param") if(ins->name() == "@param")
{ {
p.move_instruction(ins, p.begin()); m.move_instruction(ins, m.begin());
} }
ins = next; ins = next;
} }
...@@ -364,18 +364,18 @@ struct stream_info ...@@ -364,18 +364,18 @@ struct stream_info
} }
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(module& p) const find_concurrent_instructions(module& m) const
{ {
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result; std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
dominator_info di = compute_dominator(p); dominator_info di = compute_dominator(m);
result.reserve(p.size()); result.reserve(m.size());
merge_from.reserve(p.size()); merge_from.reserve(m.size());
for(auto ins : reverse_iterator_for(p)) for(auto ins : reverse_iterator_for(m))
{ {
for(auto&& arg : ins->outputs()) for(auto&& arg : ins->outputs())
{ {
if(not p.has_instruction(arg)) if(not m.has_instruction(arg))
continue; continue;
if(is_merge_point(arg)) if(is_merge_point(arg))
merge_from[ins].insert(arg); merge_from[ins].insert(arg);
...@@ -415,18 +415,18 @@ struct stream_info ...@@ -415,18 +415,18 @@ struct stream_info
} }
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(module& p) get_conflicts(module& m)
{ {
using conflict_table_type = using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
conflict_table_type conflict_table; conflict_table_type conflict_table;
auto concur_ins = this->find_concurrent_instructions(p); auto concur_ins = this->find_concurrent_instructions(m);
// Compute an index for each instruction // Compute an index for each instruction
std::unordered_map<instruction_ref, std::size_t> ins2index; std::unordered_map<instruction_ref, std::size_t> ins2index;
std::size_t index_total = 0; std::size_t index_total = 0;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
ins2index[ins] = index_total++; ins2index[ins] = index_total++;
std::vector<conflict_table_type> thread_conflict_tables( std::vector<conflict_table_type> thread_conflict_tables(
...@@ -507,21 +507,21 @@ struct stream_info ...@@ -507,21 +507,21 @@ struct stream_info
} }
}; };
void schedule::apply(module& p) const void schedule::apply(module& m) const
{ {
if(not enable) if(not enable)
return; return;
stream_info si; stream_info si;
si.calc_implicit_deps(p); si.calc_implicit_deps(m);
auto last = std::prev(p.end()); auto last = std::prev(m.end());
si.accumulate_weights(last, model); si.accumulate_weights(last, model);
auto nstreams = si.assign_streams(p, model.concurrency()); auto nstreams = si.assign_streams(m, model.concurrency());
si.sort(p, model.concurrency()); si.sort(m, model.concurrency());
if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
{ {
p.annotate(std::cout, [&](auto ins) { m.annotate(std::cout, [&](auto ins) {
if(ins->name() == "@param" and not contains(si.weights, ins)) if(ins->name() == "@param" and not contains(si.weights, ins))
return; return;
...@@ -548,9 +548,9 @@ void schedule::apply(module& p) const ...@@ -548,9 +548,9 @@ void schedule::apply(module& p) const
std::unordered_map<instruction_ref, std::size_t> ins2wait; std::unordered_map<instruction_ref, std::size_t> ins2wait;
std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for; std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited; std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
ins2wait.reserve(p.size()); ins2wait.reserve(m.size());
ins2waited.reserve(p.size()); ins2waited.reserve(m.size());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// Only schedule instructions that have a stream // Only schedule instructions that have a stream
if(not si.has_stream(ins)) if(not si.has_stream(ins))
...@@ -559,7 +559,7 @@ void schedule::apply(module& p) const ...@@ -559,7 +559,7 @@ void schedule::apply(module& p) const
// Schedule instruction on the stream // Schedule instruction on the stream
auto stream = si.get_stream(ins); auto stream = si.get_stream(ins);
assert(stream < model.concurrency()); assert(stream < model.concurrency());
model.sched(p, ins, stream); model.sched(m, ins, stream);
// Insert wait instructions // Insert wait instructions
if(si.is_merge_point(ins, stream)) if(si.is_merge_point(ins, stream))
{ {
...@@ -572,14 +572,14 @@ void schedule::apply(module& p) const ...@@ -572,14 +572,14 @@ void schedule::apply(module& p) const
if(not contains(ins2wait, i)) if(not contains(ins2wait, i))
{ {
ins2wait[i] = wait_id; ins2wait[i] = wait_id;
model.record(p, i, wait_id); model.record(m, i, wait_id);
wait_id++; wait_id++;
} }
auto w = ins2wait.at(i); auto w = ins2wait.at(i);
// If we already waited for the event on this stream then don't // If we already waited for the event on this stream then don't
// insert another wait event // insert another wait event
if(not contains(waited_for[stream], w)) if(not contains(waited_for[stream], w))
model.wait(p, ins, w); model.wait(m, ins, w);
// Store the event as waited // Store the event as waited
waited_for[stream].insert(w); waited_for[stream].insert(w);
// Store all wait events that have been waited on prior to the recorded instruction // Store all wait events that have been waited on prior to the recorded instruction
...@@ -594,7 +594,7 @@ void schedule::apply(module& p) const ...@@ -594,7 +594,7 @@ void schedule::apply(module& p) const
} }
// Add memory conflicts // Add memory conflicts
auto conflict_table = si.get_conflicts(p); auto conflict_table = si.get_conflicts(m);
for(auto&& ip : conflict_table) for(auto&& ip : conflict_table)
{ {
if(ip.second.empty()) if(ip.second.empty())
...@@ -602,7 +602,7 @@ void schedule::apply(module& p) const ...@@ -602,7 +602,7 @@ void schedule::apply(module& p) const
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
args.push_back(ip.first); args.push_back(ip.first);
args.insert(args.end(), ip.second.begin(), ip.second.end()); args.insert(args.end(), ip.second.begin(), ip.second.end());
p.insert_instruction(std::next(ip.first), make_op("identity"), args); m.insert_instruction(std::next(ip.first), make_op("identity"), args);
} }
} }
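The record/wait bookkeeping above is the usual event-synchronization pattern: the stream that produces a value records an event once, and every other stream that consumes it waits on that event at most once. A hedged standalone sketch of just that bookkeeping (types and names invented for illustration):

#include <cstddef>
#include <unordered_map>
#include <unordered_set>

struct event_tracker
{
    std::unordered_map<const void*, std::size_t> ins2wait; // producer -> event id
    std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
    std::size_t next_id = 0;

    // Returns true if `stream` must insert a wait for the event produced
    // by `producer`; the event id is assigned lazily on first use,
    // mirroring the model.record()/model.wait() calls above.
    bool needs_wait(const void* producer, std::size_t stream)
    {
        if(ins2wait.count(producer) == 0)
            ins2wait[producer] = next_id++; // record once
        auto w = ins2wait.at(producer);
        return waited_for[stream].insert(w).second; // wait at most once per stream
    }
};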
......
...@@ -86,6 +86,8 @@ struct shape_impl ...@@ -86,6 +86,8 @@ struct shape_impl
return std::accumulate( return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>()); m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
} }
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
}; };
const std::vector<shape::type_t>& shape::types() const std::vector<shape::type_t>& shape::types()
...@@ -135,6 +137,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s) ...@@ -135,6 +137,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {} shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
shape shape::from_permutation(type_t t, shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l, const std::vector<std::size_t>& l,
const std::vector<int64_t>& perm) const std::vector<int64_t>& perm)
...@@ -294,6 +298,13 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const ...@@ -294,6 +298,13 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const
return this->with_lens(this->type(), l); return this->with_lens(this->type(), l);
} }
shape shape::with_type(type_t t) const
{
auto c = impl->copy();
c->m_type = t;
return {c};
}
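Because with_type copies the whole impl before changing m_type, the lens and strides of the original shape are preserved. A hypothetical usage sketch:

// Same dimensions and layout; only the element type changes.
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape h = s.with_type(migraphx::shape::half_type);
assert(h.lens() == s.lens() and h.strides() == s.strides());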
std::size_t shape::element_space() const { return impl->element_space(); } std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); } std::string shape::type_string() const { return name(this->type()); }
......
...@@ -42,7 +42,7 @@ struct find_mul_conv ...@@ -42,7 +42,7 @@ struct find_mul_conv
match::name("broadcast").bind("a"))); match::name("broadcast").bind("a")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto conv_ins = r.instructions["conv"]; auto conv_ins = r.instructions["conv"];
...@@ -53,14 +53,14 @@ struct find_mul_conv ...@@ -53,14 +53,14 @@ struct find_mul_conv
if(broadcast_op.axis != 1) if(broadcast_op.axis != 1)
return; return;
auto new_a = p.insert_instruction( auto new_a = m.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}), make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front()); a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins); auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = p.insert_instruction( auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul); ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
p.replace_instruction(ins, new_conv); m.replace_instruction(ins, new_conv);
} }
}; };
...@@ -80,7 +80,7 @@ struct find_mul_slice_conv ...@@ -80,7 +80,7 @@ struct find_mul_slice_conv
match::name("broadcast")(match::is_constant()).bind("a"))); match::name("broadcast")(match::is_constant()).bind("a")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto slice_ins = r.instructions["slice"]; auto slice_ins = r.instructions["slice"];
...@@ -116,38 +116,38 @@ struct find_mul_slice_conv ...@@ -116,38 +116,38 @@ struct find_mul_slice_conv
auto w_slice_op = slice_op; auto w_slice_op = slice_op;
w_slice_op.axes = {0}; w_slice_op.axes = {0};
auto slice_w_ins = p.insert_instruction(ins, w_slice_op, w_ins); auto slice_w_ins = m.insert_instruction(ins, w_slice_op, w_ins);
auto new_a = p.insert_instruction( auto new_a = m.insert_instruction(
ins, ins,
make_op("broadcast", {{"axis", 0}, {"out_lens", slice_w_ins->get_shape().lens()}}), make_op("broadcast", {{"axis", 0}, {"out_lens", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front()); a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins); auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
std::vector<instruction_ref> sliced_weights; std::vector<instruction_ref> sliced_weights;
if(slice_op.starts.front() != 0) if(slice_op.starts.front() != 0)
sliced_weights.push_back(p.insert_instruction( sliced_weights.push_back(m.insert_instruction(
ins, ins,
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", slice_op.starts}}), make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", slice_op.starts}}),
w_ins)); w_ins));
sliced_weights.push_back(new_mul); sliced_weights.push_back(new_mul);
int64_t end_axis = w_ins->get_shape().lens().at(0); int64_t end_axis = w_ins->get_shape().lens().at(0);
if(slice_op.ends.front() != end_axis) if(slice_op.ends.front() != end_axis)
sliced_weights.push_back(p.insert_instruction( sliced_weights.push_back(m.insert_instruction(
ins, ins,
make_op("slice", {{"axes", {0}}, {"starts", slice_op.ends}, {"ends", {end_axis}}}), make_op("slice", {{"axes", {0}}, {"starts", slice_op.ends}, {"ends", {end_axis}}}),
w_ins)); w_ins));
auto new_weights = auto new_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 0}}), sliced_weights); m.insert_instruction(ins, make_op("concat", {{"axis", 0}}), sliced_weights);
auto new_conv = p.insert_instruction( auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights); ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights);
assert(conv_ins->get_shape() == new_conv->get_shape()); assert(conv_ins->get_shape() == new_conv->get_shape());
auto slice1 = p.insert_instruction(ins, slice_op, new_conv); auto slice1 = m.insert_instruction(ins, slice_op, new_conv);
assert(ins->get_shape().lens() == slice1->get_shape().lens()); assert(ins->get_shape().lens() == slice1->get_shape().lens());
p.replace_instruction(ins, slice1); m.replace_instruction(ins, slice1);
// TODO: Check each slice doesn't overlap and that it occurs after slice_ins // TODO: Check each slice doesn't overlap and that it occurs after slice_ins
auto outputs = conv_ins->outputs(); auto outputs = conv_ins->outputs();
for(auto output : outputs) for(auto output : outputs)
...@@ -171,7 +171,7 @@ struct find_mul_add ...@@ -171,7 +171,7 @@ struct find_mul_add
match::is_constant().bind("a"))); match::is_constant().bind("a")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
...@@ -179,9 +179,9 @@ struct find_mul_add ...@@ -179,9 +179,9 @@ struct find_mul_add
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
assert(x_ins != b_ins); assert(x_ins != b_ins);
auto ax_ins = p.insert_instruction(ins, make_op("mul"), a_ins, x_ins); auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = p.insert_instruction(ins, make_op("mul"), a_ins, b_ins); auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
p.replace_instruction(ins, make_op("add"), ax_ins, ab_ins); m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
} }
}; };
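find_mul_add applies the distributive law: a * (x + b) becomes a*x + a*b. Since a and b are both constants, the new a*b product is itself constant and can be folded into a single literal by later passes, e.g. 2 * (x + 3) ends up as 2*x + 6.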
...@@ -193,15 +193,15 @@ struct find_add_lit_broadcast ...@@ -193,15 +193,15 @@ struct find_add_lit_broadcast
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b"))); match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"]; auto b_ins = r.instructions["b"];
auto sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins); auto sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins);
p.replace_instruction(ins, make_op("add"), x_ins, sumab); m.replace_instruction(ins, make_op("add"), x_ins, sumab);
} }
}; };
...@@ -213,7 +213,7 @@ struct find_double_add_lit_broadcast ...@@ -213,7 +213,7 @@ struct find_double_add_lit_broadcast
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y"))); match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -228,17 +228,17 @@ struct find_double_add_lit_broadcast ...@@ -228,17 +228,17 @@ struct find_double_add_lit_broadcast
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape()) if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return; return;
auto op = a_ins->get_operator(); auto op = a_ins->get_operator();
auto presum = p.insert_instruction( auto presum = m.insert_instruction(
ins, make_op("add"), a_ins->inputs().at(0), b_ins->inputs().at(0)); ins, make_op("add"), a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = p.insert_instruction(ins, op, presum); sumab = m.insert_instruction(ins, op, presum);
} }
else else
{ {
sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins); sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins);
} }
auto sumxy = p.insert_instruction(ins, make_op("add"), x_ins, y_ins); auto sumxy = m.insert_instruction(ins, make_op("add"), x_ins, y_ins);
p.replace_instruction(ins, make_op("add"), sumxy, sumab); m.replace_instruction(ins, make_op("add"), sumxy, sumab);
} }
}; };
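find_double_add_lit_broadcast reassociates (x + a) + (y + b), where a and b are broadcast literals, into (x + y) + (a + b); when a and b broadcast identically shaped literals, their sum is computed on the pre-broadcast values and broadcast once. Either way the constant part becomes a single foldable subexpression.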
...@@ -251,7 +251,7 @@ struct find_inner_broadcast ...@@ -251,7 +251,7 @@ struct find_inner_broadcast
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y"))); match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
...@@ -263,9 +263,9 @@ struct find_inner_broadcast ...@@ -263,9 +263,9 @@ struct find_inner_broadcast
if(xbroadcast.axis != ybroadcast.axis) if(xbroadcast.axis != ybroadcast.axis)
return; return;
auto op = p.insert_instruction( auto op = m.insert_instruction(
ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front()); ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
p.replace_instruction(ins, xbroadcast, op); m.replace_instruction(ins, xbroadcast, op);
} }
}; };
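find_inner_broadcast hoists an elementwise op above two matching broadcasts: op(broadcast(x), broadcast(y)) with the same broadcast axis becomes broadcast(op(x, y)), so the op runs once on the small pre-broadcast tensors instead of on the broadcast result.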
...@@ -296,7 +296,7 @@ struct find_concat_op ...@@ -296,7 +296,7 @@ struct find_concat_op
return op.name() == "broadcast" or op.attributes().contains("pointwise"); return op.name() == "broadcast" or op.attributes().contains("pointwise");
} }
void apply(module& p, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto axis = any_cast<op::concat>(ins->get_operator()).axis; auto axis = any_cast<op::concat>(ins->get_operator()).axis;
...@@ -330,12 +330,11 @@ struct find_concat_op ...@@ -330,12 +330,11 @@ struct find_concat_op
return j->inputs().at(i); return j->inputs().at(i);
}); });
auto concat = auto concat =
p.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs); m.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs);
concats.push_back(concat); concats.push_back(concat);
} }
auto y = p.insert_instruction(ins, op, concats); auto y = m.insert_instruction(ins, op, concats);
return {y}; return {y};
}; };
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
...@@ -350,9 +349,9 @@ struct find_concat_op ...@@ -350,9 +349,9 @@ struct find_concat_op
}; };
group_unique(ins->inputs().begin(), ins->inputs().end(), update_args, pred); group_unique(ins->inputs().begin(), ins->inputs().end(), update_args, pred);
if(args.size() == 1) if(args.size() == 1)
p.replace_instruction(ins, args.front()); m.replace_instruction(ins, args.front());
else else
p.replace_instruction(ins, make_op("concat", {{"axis", axis}}), args); m.replace_instruction(ins, make_op("concat", {{"axis", axis}}), args);
} }
}; };
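find_concat_op pushes a concat below pointwise ops and broadcasts: consecutive concat inputs that share the same operation are rewritten so their own inputs are concatenated first and the operation is applied once to the larger tensor, i.e. concat(op(a), op(b)) becomes op(concat(a, b)) for each input position of the op.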
...@@ -479,14 +478,14 @@ struct find_splits ...@@ -479,14 +478,14 @@ struct find_splits
return true; return true;
} }
void apply(module& p, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto splits = get_splits(ins); auto splits = get_splits(ins);
if(splits.empty()) if(splits.empty())
return; return;
for(const auto& group : get_split_groups(p, splits)) for(const auto& group : get_split_groups(m, splits))
{ {
auto start = group.front(); auto start = group.front();
auto split_front = splits.front(); auto split_front = splits.front();
...@@ -501,10 +500,10 @@ struct find_splits ...@@ -501,10 +500,10 @@ struct find_splits
std::next(group.begin()), group.end(), [&](auto i) { return i == start; })); std::next(group.begin()), group.end(), [&](auto i) { return i == start; }));
auto split_idx = 0; auto split_idx = 0;
instruction_ref c = p.end(); instruction_ref c = m.end();
if(start->inputs().size() == 1) if(start->inputs().size() == 1)
{ {
c = p.insert_instruction(std::next(ins), op, ins); c = m.insert_instruction(std::next(ins), op, ins);
} }
else if(start->inputs().size() == 2) else if(start->inputs().size() == 2)
{ {
...@@ -531,7 +530,7 @@ struct find_splits ...@@ -531,7 +530,7 @@ struct find_splits
return; return;
for(auto data : data_args) for(auto data : data_args)
p.move_instructions(data, ins); m.move_instructions(data, ins);
auto slice_op = any_cast<op::slice>(splits.front()->get_operator()); auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
assert(not slice_op.axes.empty()); assert(not slice_op.axes.empty());
...@@ -539,16 +538,16 @@ struct find_splits ...@@ -539,16 +538,16 @@ struct find_splits
return; return;
auto concat_axis = slice_op.axes.front(); auto concat_axis = slice_op.axes.front();
// TODO: Check if axes match // TODO: Check if axes match
auto concat = p.insert_instruction( auto concat = m.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), data_args); ins, make_op("concat", {{"axis", concat_axis}}), data_args);
std::vector<instruction_ref> args; std::vector<instruction_ref> args;
args.resize(2); args.resize(2);
args[split_idx] = ins; args[split_idx] = ins;
args[data_idx] = concat; args[data_idx] = concat;
c = p.insert_instruction(std::next(ins), op, args); c = m.insert_instruction(std::next(ins), op, args);
} }
if(c != p.end()) if(c != m.end())
{ {
for(auto i : group) for(auto i : group)
{ {
...@@ -561,11 +560,11 @@ struct find_splits ...@@ -561,11 +560,11 @@ struct find_splits
if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name())) if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
continue; continue;
auto x = auto x =
p.insert_instruction(output, make_op("contiguous"), output->inputs()); m.insert_instruction(output, make_op("contiguous"), output->inputs());
p.replace_instruction(output, output->get_operator(), x); m.replace_instruction(output, output->get_operator(), x);
} }
p.replace_instruction(i, split->get_operator(), c); m.replace_instruction(i, split->get_operator(), c);
} }
} }
} }
...@@ -580,7 +579,7 @@ struct find_split_concat ...@@ -580,7 +579,7 @@ struct find_split_concat
match::name("slice")(match::all_of[match::outputs()](match::name("concat"))))); match::name("slice")(match::all_of[match::outputs()](match::name("concat")))));
} }
void apply(module& p, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -620,9 +619,9 @@ struct find_split_concat ...@@ -620,9 +619,9 @@ struct find_split_concat
args.erase(std::next(it), it + splits.size()); args.erase(std::next(it), it + splits.size());
if(args.size() == 1) if(args.size() == 1)
p.replace_instruction(concat, args.front()); m.replace_instruction(concat, args.front());
else else
p.replace_instruction(concat, concat->get_operator(), args); m.replace_instruction(concat, concat->get_operator(), args);
} }
}; };
...@@ -665,7 +664,7 @@ struct find_add_convs ...@@ -665,7 +664,7 @@ struct find_add_convs
return x.stride[0] / y.stride[0]; return x.stride[0] / y.stride[0];
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto a_conv = r.instructions["a"]; auto a_conv = r.instructions["a"];
...@@ -694,7 +693,7 @@ struct find_add_convs ...@@ -694,7 +693,7 @@ struct find_add_convs
if(n == 0) if(n == 0)
return; return;
new_op = a_op; new_op = a_op;
b_input = p.insert_instruction( b_input = m.insert_instruction(
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), b_input); ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), b_input);
} }
else if(b_op.stride < a_op.stride) else if(b_op.stride < a_op.stride)
...@@ -703,7 +702,7 @@ struct find_add_convs ...@@ -703,7 +702,7 @@ struct find_add_convs
if(n == 0) if(n == 0)
return; return;
new_op = b_op; new_op = b_op;
a_input = p.insert_instruction( a_input = m.insert_instruction(
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), a_input); ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), a_input);
} }
else else
...@@ -714,10 +713,10 @@ struct find_add_convs ...@@ -714,10 +713,10 @@ struct find_add_convs
} }
auto concat_input = auto concat_input =
p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_input, b_input); m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_input, b_input);
auto concat_weights = auto concat_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_weights, b_weights); m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_weights, b_weights);
p.replace_instruction(ins, new_op, concat_input, concat_weights); m.replace_instruction(ins, new_op, concat_input, concat_weights);
} }
}; };
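find_add_convs fuses add(conv(x, w1), conv(y, w2)) into a single convolution, conv(concat(x, y, axis=1), concat(w1, w2, axis=1)), after reconciling unequal strides with a step op; since a convolution sums over the input-channel axis, convolving the channel-concatenated input with the channel-concatenated weights computes exactly the original sum.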
...@@ -738,7 +737,7 @@ struct find_conv_dot_horiz_fusion ...@@ -738,7 +737,7 @@ struct find_conv_dot_horiz_fusion
{ {
auto matcher() const { return horiz_conv_dot(); } auto matcher() const { return horiz_conv_dot(); }
void apply(module& p, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -786,16 +785,16 @@ struct find_conv_dot_horiz_fusion ...@@ -786,16 +785,16 @@ struct find_conv_dot_horiz_fusion
} }
for(auto arg : args) for(auto arg : args)
p.move_instructions(arg, input); m.move_instructions(arg, input);
// TODO: Check if axes match // TODO: Check if axes match
auto concat = auto concat =
p.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args); m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = p.insert_instruction(std::next(input), op, input, concat); auto fused = m.insert_instruction(std::next(input), op, input, concat);
int64_t offset = 0; int64_t offset = 0;
for(auto arg : range(start, last)) for(auto arg : range(start, last))
{ {
int64_t len = arg->get_shape().lens()[axis]; int64_t len = arg->get_shape().lens()[axis];
p.replace_instruction( m.replace_instruction(
arg, arg,
make_op("slice", make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}), {{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
...@@ -816,16 +815,16 @@ struct find_div_const ...@@ -816,16 +815,16 @@ struct find_div_const
return match::name("div")(match::arg(1)(match::is_constant().bind("c"))); return match::name("div")(match::arg(1)(match::is_constant().bind("c")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
auto recip = p.insert_instruction(std::next(c_ins), make_op("recip"), c_ins); auto recip = m.insert_instruction(std::next(c_ins), make_op("recip"), c_ins);
auto args = ins->inputs(); auto args = ins->inputs();
p.replace_instruction(ins, make_op("mul"), args.front(), recip); m.replace_instruction(ins, make_op("mul"), args.front(), recip);
} }
}; };
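find_div_const rewrites x / c as x * (1/c): the recip is inserted right after the constant so it is computed once and can be constant-folded, leaving a cheaper multiply on the hot path. find_sub_const below is the analogous rewrite of x - c into x + (-c).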
...@@ -836,16 +835,16 @@ struct find_sub_const ...@@ -836,16 +835,16 @@ struct find_sub_const
return match::name("sub")(match::arg(1)(match::is_constant().bind("c"))); return match::name("sub")(match::arg(1)(match::is_constant().bind("c")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto c_ins = r.instructions["c"]; auto c_ins = r.instructions["c"];
auto neg = p.insert_instruction(std::next(c_ins), make_op("neg"), c_ins); auto neg = m.insert_instruction(std::next(c_ins), make_op("neg"), c_ins);
auto args = ins->inputs(); auto args = ins->inputs();
p.replace_instruction(ins, make_op("add"), args.front(), neg); m.replace_instruction(ins, make_op("add"), args.front(), neg);
} }
}; };
...@@ -857,12 +856,12 @@ struct find_rsqrt ...@@ -857,12 +856,12 @@ struct find_rsqrt
match::name("sqrt")(match::used_once(), match::args(match::any().bind("x"))))); match::name("sqrt")(match::used_once(), match::args(match::any().bind("x")))));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
p.replace_instruction(ins, make_op("rsqrt"), x_ins); m.replace_instruction(ins, make_op("rsqrt"), x_ins);
} }
}; };
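find_rsqrt fuses recip(sqrt(x)) into a single rsqrt(x) instruction; the match::used_once() constraint guards the rewrite so the intermediate sqrt is only eliminated when nothing else consumes it.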
...@@ -882,7 +881,7 @@ struct find_split_reshape ...@@ -882,7 +881,7 @@ struct find_split_reshape
.bind("reshape"); .bind("reshape");
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"]; auto rsp = r.instructions["reshape"];
...@@ -937,14 +936,14 @@ struct find_split_reshape ...@@ -937,14 +936,14 @@ struct find_split_reshape
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction // insert the reshape instruction
auto rsp_ins = p.insert_instruction( auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input); std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace each original reshape with a slice // replace each original reshape with a slice
int64_t start = 0; int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i) for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{ {
p.replace_instruction( m.replace_instruction(
vec_rsp[i], vec_rsp[i],
make_op( make_op(
"slice", "slice",
...@@ -963,7 +962,7 @@ struct find_split_transpose ...@@ -963,7 +962,7 @@ struct find_split_transpose
.bind("trans"); .bind("trans");
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto trans = r.instructions["trans"]; auto trans = r.instructions["trans"];
...@@ -989,14 +988,14 @@ struct find_split_transpose ...@@ -989,14 +988,14 @@ struct find_split_transpose
} }
// insert a transpose instruction // insert a transpose instruction
auto tr = p.insert_instruction( auto tr = m.insert_instruction(
std::next(input), make_op("transpose", {{"permutation", perm}}), input); std::next(input), make_op("transpose", {{"permutation", perm}}), input);
// compute the axis in the slice // compute the axis in the slice
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front(); auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
auto it = std::find(perm.begin(), perm.end(), axis); auto it = std::find(perm.begin(), perm.end(), axis);
assert(it != perm.end()); assert(it != perm.end());
auto axis_new = static_cast<int64_t>(std::distance(perm.begin(), it)); int64_t axis_new = std::distance(perm.begin(), it);
for(auto in : split_outputs) for(auto in : split_outputs)
{ {
...@@ -1004,7 +1003,7 @@ struct find_split_transpose ...@@ -1004,7 +1003,7 @@ struct find_split_transpose
auto starts = oper.starts; auto starts = oper.starts;
auto ends = oper.ends; auto ends = oper.ends;
auto tr_orig = in->outputs().front(); auto tr_orig = in->outputs().front();
p.replace_instruction( m.replace_instruction(
tr_orig, tr_orig,
make_op("slice", {{"axes", {axis_new}}, {"starts", starts}, {"ends", ends}}), make_op("slice", {{"axes", {axis_new}}, {"starts", starts}, {"ends", ends}}),
tr); tr);
...@@ -1012,12 +1011,12 @@ struct find_split_transpose ...@@ -1012,12 +1011,12 @@ struct find_split_transpose
} }
}; };
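The axis remapping above simply looks up where the sliced axis lands after the transpose: output dimension j of a transpose is input dimension perm[j], so the new slice axis is the position of the old axis inside perm. A minimal standalone check of that arithmetic:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    std::vector<int64_t> perm{0, 2, 1, 3};
    int64_t axis = 2; // slice axis before the transpose
    auto it = std::find(perm.begin(), perm.end(), axis);
    int64_t axis_new = std::distance(perm.begin(), it);
    assert(axis_new == 1); // the slice moves from axis 2 to axis 1
}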
void simplify_algebra::apply(module& p) const void simplify_algebra::apply(module& m) const
{ {
// Run simplifications multiple times // Run simplifications multiple times
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
{ {
match::find_matches(p, match::find_matches(m,
find_inner_broadcast{}, find_inner_broadcast{},
find_double_add_lit_broadcast{}, find_double_add_lit_broadcast{},
find_add_lit_broadcast{}, find_add_lit_broadcast{},
...@@ -1034,7 +1033,7 @@ void simplify_algebra::apply(module& p) const ...@@ -1034,7 +1033,7 @@ void simplify_algebra::apply(module& p) const
find_splits{}, find_splits{},
find_split_reshape{}, find_split_reshape{},
find_split_transpose{}); find_split_transpose{});
dead_code_elimination{}.apply(p); dead_code_elimination{}.apply(m);
} }
} }
......
...@@ -53,7 +53,7 @@ struct match_find_quantizable_ops ...@@ -53,7 +53,7 @@ struct match_find_quantizable_ops
match::arg(1)(dequantizelinear_op("x2", "scale2"))); match::arg(1)(dequantizelinear_op("x2", "scale2")));
} }
void apply(module& m, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto qop = r.result; auto qop = r.result;
auto q1 = r.instructions["x1"]; auto q1 = r.instructions["x1"];
......
...@@ -70,19 +70,19 @@ struct find_reshaper ...@@ -70,19 +70,19 @@ struct find_reshaper
match::any_of[match::outputs()](match::name(reshaper_names()))); match::any_of[match::outputs()](match::name(reshaper_names())));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins}; std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back())) while(is_reshaper(reshapes.back()))
{ {
assert(!reshapes.back()->inputs().empty()); assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front())); assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front(); auto input = reshapes.back()->inputs().front();
reshapes.push_back(input); reshapes.push_back(input);
} }
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()}; std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()};
for(auto start : iterator_for(reshapes)) for(auto start : iterator_for(reshapes))
{ {
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) { auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
...@@ -96,7 +96,7 @@ struct find_reshaper ...@@ -96,7 +96,7 @@ struct find_reshaper
} }
if(r.first != r.second) if(r.first != r.second)
{ {
p.replace_instruction(r.first, r.second); m.replace_instruction(r.first, r.second);
} }
} }
}; };
...@@ -117,10 +117,10 @@ struct find_nop_reshapes ...@@ -117,10 +117,10 @@ struct find_nop_reshapes
return match::name(reshapes)(match::same_shape(match::arg(0))); return match::name(reshapes)(match::same_shape(match::arg(0)));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
p.replace_instruction(ins, ins->inputs().front()); m.replace_instruction(ins, ins->inputs().front());
} }
}; };
...@@ -132,7 +132,7 @@ struct find_transpose ...@@ -132,7 +132,7 @@ struct find_transpose
match::skip_output(match::name("contiguous"))(match::name("transpose")))); match::skip_output(match::name("contiguous"))(match::name("transpose"))));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto x = ins; auto x = ins;
...@@ -149,11 +149,11 @@ struct find_transpose ...@@ -149,11 +149,11 @@ struct find_transpose
return; return;
if(is_no_transpose(dims)) if(is_no_transpose(dims))
{ {
p.replace_instruction(ins, t->inputs().front()); m.replace_instruction(ins, t->inputs().front());
} }
else else
{ {
p.replace_instruction( m.replace_instruction(
ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front()); ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front());
} }
} }
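find_transpose collapses chains of transposes by composing their permutations and dropping the result when it is the identity. A small self-contained sketch of both steps (compose and is_identity are illustrative names for what the pass computes inline):

#include <cstddef>
#include <cstdint>
#include <vector>

// Applying transpose(first) and then transpose(second) equals a single
// transpose whose permutation reads through both maps.
std::vector<std::int64_t> compose(const std::vector<std::int64_t>& first,
                                  const std::vector<std::int64_t>& second)
{
    std::vector<std::int64_t> out(second.size());
    for(std::size_t i = 0; i < out.size(); i++)
        out[i] = first[second[i]];
    return out;
}

bool is_identity(const std::vector<std::int64_t>& perm)
{
    for(std::size_t i = 0; i < perm.size(); i++)
        if(perm[i] != static_cast<std::int64_t>(i))
            return false;
    return true;
}

// e.g. compose({0, 2, 1}, {0, 2, 1}) == {0, 1, 2}: two swaps of the last
// axes cancel, so the pass replaces the pair with its original input.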
...@@ -223,7 +223,7 @@ struct find_nested_slice ...@@ -223,7 +223,7 @@ struct find_nested_slice
return result; return result;
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto slice = ins->inputs().front(); auto slice = ins->inputs().front();
...@@ -241,7 +241,7 @@ struct find_nested_slice ...@@ -241,7 +241,7 @@ struct find_nested_slice
op.starts.push_back(pp.second.first); op.starts.push_back(pp.second.first);
op.ends.push_back(pp.second.second); op.ends.push_back(pp.second.second);
} }
p.replace_instruction(ins, op, input); m.replace_instruction(ins, op, input);
} }
}; };
...@@ -252,7 +252,7 @@ struct find_concat_transpose ...@@ -252,7 +252,7 @@ struct find_concat_transpose
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape())); return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto trans_inputs = ins->inputs(); auto trans_inputs = ins->inputs();
...@@ -279,14 +279,14 @@ struct find_concat_transpose ...@@ -279,14 +279,14 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
std::transform( std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) { ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction( return m.insert_instruction(
ins, make_op("transpose", {{"permutation", permutation}}), i); ins, make_op("transpose", {{"permutation", permutation}}), i);
}); });
auto concat = p.insert_instruction(ins, op, inputs); auto concat = m.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction( auto t = m.insert_instruction(
ins, make_op("transpose", {{"permutation", ipermutation}}), concat); ins, make_op("transpose", {{"permutation", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens()); assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t); m.replace_instruction(ins, t);
} }
}; };
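The ipermutation above undoes the shared permutation: transposing with perm and then with its inverse restores the original layout, which lets the pass concat in the transposed space and transpose back once at the end. A sketch of the inversion (MIGraphX has its own helper for this; the function below is illustrative):

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<std::int64_t> invert_permutation(const std::vector<std::int64_t>& perm)
{
    std::vector<std::int64_t> inv(perm.size());
    for(std::size_t i = 0; i < perm.size(); i++)
        inv[perm[i]] = static_cast<std::int64_t>(i);
    return inv;
}

// invert_permutation({0, 2, 3, 1}) == {0, 3, 1, 2}; composing the two in
// either order yields the identity, so the trailing transpose returns the
// concat result to the layout the consumers expect.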
...@@ -303,7 +303,7 @@ struct find_nested_concat ...@@ -303,7 +303,7 @@ struct find_nested_concat
return op.axis; return op.axis;
} }
void apply(module& p, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
auto axis = get_axis(ins); auto axis = get_axis(ins);
...@@ -316,9 +316,8 @@ struct find_nested_concat ...@@ -316,9 +316,8 @@ struct find_nested_concat
else else
args.push_back(i); args.push_back(i);
} }
})(ins->inputs()); })(ins->inputs());
p.replace_instruction(ins, ins->get_operator(), args); m.replace_instruction(ins, ins->get_operator(), args);
} }
}; };
...@@ -330,7 +329,7 @@ struct find_resize ...@@ -330,7 +329,7 @@ struct find_resize
match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind"))); match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto ins_rsp = r.instructions["data"]; auto ins_rsp = r.instructions["data"];
...@@ -418,13 +417,13 @@ struct find_resize ...@@ -418,13 +417,13 @@ struct find_resize
} }
auto in_rsp = ins_rsp->inputs().front(); auto in_rsp = ins_rsp->inputs().front();
auto rsp_data = p.insert_instruction( auto rsp_data = m.insert_instruction(
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp); ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = p.insert_instruction( auto mb_rsp = m.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data); ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp); auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end()); std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb); m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
} }
}; };
...@@ -437,7 +436,7 @@ struct find_where_op ...@@ -437,7 +436,7 @@ struct find_where_op
match::is_constant().bind("ind"))); match::is_constant().bind("ind")));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto concat = r.instructions["data"]; auto concat = r.instructions["data"];
...@@ -476,11 +475,11 @@ struct find_where_op ...@@ -476,11 +475,11 @@ struct find_where_op
if(val) if(val)
{ {
p.replace_instruction(ins, inputs.at(0)); m.replace_instruction(ins, inputs.at(0));
} }
else else
{ {
p.replace_instruction(ins, inputs.at(1)); m.replace_instruction(ins, inputs.at(1));
} }
} }
}; };
...@@ -497,7 +496,7 @@ struct find_reshape_cont ...@@ -497,7 +496,7 @@ struct find_reshape_cont
match::any())); match::any()));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto ins_cont = r.instructions["cont"]; auto ins_cont = r.instructions["cont"];
...@@ -531,11 +530,11 @@ struct find_reshape_cont ...@@ -531,11 +530,11 @@ struct find_reshape_cont
else else
{ {
inputs.push_back( inputs.push_back(
p.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), in)); m.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), in));
} }
} }
auto out = p.insert_instruction(ins, ins->get_operator(), inputs); auto out = m.insert_instruction(ins, ins->get_operator(), inputs);
p.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), out); m.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), out);
} }
}; };
...@@ -565,25 +564,25 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -565,25 +564,25 @@ struct find_transpose_contiguous_reshaper_unary
match::args(match_transpose_contiguous_reshaper())); match::args(match_transpose_contiguous_reshaper()));
} }
void apply(module& p, match::matcher_result r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"]; auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"]; auto trans_ins = r.instructions["trans_ins"];
auto cont_ins = r.instructions["cont_ins"]; auto cont_ins = r.instructions["cont_ins"];
auto unary_op_name = ins->get_operator().name(); auto unary_op_name = ins->get_operator().name();
auto unary_ins = p.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins); auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = p.insert_instruction(cont_ins, make_op("contiguous"), unary_ins); auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
// older cont and reshape are removed by dead code elimination // older cont and reshape are removed by dead code elimination
p.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins); m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins);
} }
}; };
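This rewrite relies on unary elementwise ops commuting with pure layout ops: contiguous and reshape only relabel indices, so running the unary op before or after them touches the same values. A tiny standalone check of that commutation (illustrative example, not MIGraphX code):

#include <cassert>
#include <cmath>
#include <vector>

int main()
{
    // before: x -> transpose -> contiguous -> reshape -> unary
    // after:  x -> transpose -> unary -> contiguous -> reshape
    std::vector<double> flat = {-1.0, 2.0, -3.0, 4.0};
    std::vector<double> unary_then_layout;
    for(double v : flat)
        unary_then_layout.push_back(std::fabs(v));
    // A "reshape" of a contiguous buffer keeps the same flat data, so
    // applying abs after it produces identical values.
    std::vector<double> layout_then_unary = flat;
    for(double& v : layout_then_unary)
        v = std::fabs(v);
    assert(unary_then_layout == layout_then_unary);
}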
void simplify_reshapes::apply(module& p) const void simplify_reshapes::apply(module& m) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 2; i++)
{ {
match::find_matches(p, match::find_matches(m,
find_where_op{}, find_where_op{},
find_resize{}, find_resize{},
find_reshape_cont{}, find_reshape_cont{},
...@@ -595,7 +594,7 @@ void simplify_reshapes::apply(module& p) const ...@@ -595,7 +594,7 @@ void simplify_reshapes::apply(module& p) const
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
find_transpose_contiguous_reshaper_unary{}); find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(p); dead_code_elimination{}.apply(m);
} }
} }
......
...@@ -20,7 +20,6 @@ struct cpu_copy : reduce_dims_base, auto_register_op<cpu_copy> ...@@ -20,7 +20,6 @@ struct cpu_copy : reduce_dims_base, auto_register_op<cpu_copy>
return inputs.at(1); return inputs.at(1);
} }
argument argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
argument result = get_arg(args, args.size() - 1); argument result = get_arg(args, args.size() - 1);
......
...@@ -26,7 +26,6 @@ struct cpu_gather : auto_register_op<cpu_gather> ...@@ -26,7 +26,6 @@ struct cpu_gather : auto_register_op<cpu_gather>
} }
argument argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
std::size_t nelements = output_shape.elements(); std::size_t nelements = output_shape.elements();
......
...@@ -7,7 +7,16 @@ ...@@ -7,7 +7,16 @@
#ifdef MIGRAPHX_DISABLE_OMP #ifdef MIGRAPHX_DISABLE_OMP
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#else #else
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier"
#endif
#include <omp.h> #include <omp.h>
#ifdef __clang__
#pragma clang diagnostic pop
#endif
#endif #endif
namespace migraphx { namespace migraphx {
......
...@@ -213,7 +213,6 @@ template <std::size_t N, class... Xs> ...@@ -213,7 +213,6 @@ template <std::size_t N, class... Xs>
bool is_vectorizable(const Xs&... xs) bool is_vectorizable(const Xs&... xs)
{ {
return all_of({xs...}, [](const auto& s) { return all_of({xs...}, [](const auto& s) {
if(s.standard() and (s.lens().back() % N) == 0) if(s.standard() and (s.lens().back() % N) == 0)
return true; return true;
if(s.broadcasted()) if(s.broadcasted())
...@@ -320,11 +319,10 @@ struct cpu_unary : reduce_dims_base, auto_register_op<cpu_unary<Op>> ...@@ -320,11 +319,10 @@ struct cpu_unary : reduce_dims_base, auto_register_op<cpu_unary<Op>>
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(2); check_shapes{inputs, *this}.has(2);
auto s = inputs.at(0); const auto& s = inputs.at(0);
return {s.type(), s.lens()}; return {s.type(), s.lens()};
} }
argument argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
argument result = get_arg(args, args.size() - 1); argument result = get_arg(args, args.size() - 1);
...@@ -358,12 +356,11 @@ struct cpu_binary : reduce_dims_base, auto_register_op<cpu_binary<Op>> ...@@ -358,12 +356,11 @@ struct cpu_binary : reduce_dims_base, auto_register_op<cpu_binary<Op>>
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(3); check_shapes{inputs, *this}.has(3);
auto s = inputs.at(0); const auto& s = inputs.at(0);
return {s.type(), s.lens()}; return {s.type(), s.lens()};
} }
argument argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
argument result = get_arg(args, args.size() - 1); argument result = get_arg(args, args.size() - 1);
......
...@@ -223,7 +223,7 @@ struct cpu_unary2 : auto_register_op<cpu_unary2<Op>> ...@@ -223,7 +223,7 @@ struct cpu_unary2 : auto_register_op<cpu_unary2<Op>>
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0); const auto& s = inputs.at(0);
return {s.type(), s.lens()}; return {s.type(), s.lens()};
} }
...@@ -352,7 +352,7 @@ struct cpu_apply ...@@ -352,7 +352,7 @@ struct cpu_apply
std::transform(bind_inputs.begin(), std::transform(bind_inputs.begin(),
bind_inputs.end(), bind_inputs.end(),
std::back_inserter(inputs), std::back_inserter(inputs),
[&](const auto& s) { return r.instructions.at(s); }); [&](const auto& s) { return r.instructions[s]; });
inputs.push_back(this->insert_allocation(ins, ins->get_shape())); inputs.push_back(this->insert_allocation(ins, ins->get_shape()));
modl->replace_instruction(ins, op, inputs); modl->replace_instruction(ins, op, inputs);
}); });
...@@ -460,11 +460,6 @@ struct cpu_apply ...@@ -460,11 +460,6 @@ struct cpu_apply
if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and
not v["ceil_mode"].to<bool>()) not v["ceil_mode"].to<bool>())
return replace(ins, make_op("dnnl::pooling", op.to_value())); return replace(ins, make_op("dnnl::pooling", op.to_value()));
std::string mode = v["mode"].to<std::string>();
if(mode == "max")
return replace(ins, make_op("cpu::pooling_max", v));
else if(mode == "average")
return replace(ins, make_op("cpu::pooling_average", v));
return ins; return ins;
} }
......
...@@ -11,125 +11,14 @@ namespace migraphx { ...@@ -11,125 +11,14 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace cpu { namespace cpu {
struct max_pool
{
static std::string name() { return "max"; }
template <class T>
static T start()
{
return std::numeric_limits<T>::lowest();
}
static double apply(double x, double y)
{
double m = std::max(x, y);
return (m);
}
static double final(double x, std::size_t) { return (x); }
};
struct avg_pool
{
static std::string name() { return "average"; }
template <class T>
static double start()
{
return 0.0;
}
static double apply(double x, double y) { return x + y; }
static double final(double x, std::size_t y) { return (y == 0) ? 0.0 : (x / y); }
};
template <class Op>
struct cpu_pooling : auto_register_op<cpu_pooling<Op>>
{
cpu_pooling() = default;
cpu_pooling(op::pooling pop) : op(std::move(pop)) {}
op::pooling op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "cpu::pooling_" + Op::name(); }
shape compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.normalize_compute_shape(inputs);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
visit_all(args.back(), args[0])([&](auto output, auto input) {
using type = typename decltype(output)::value_type;
auto in_s = input.get_shape();
auto in_lens = in_s.lens();
std::vector<std::size_t> vec_len(in_lens.begin() + 2, in_lens.end());
par_for(output_shape.elements(), [&](auto i) {
auto idx_o = output_shape.multi(i);
auto n_dim = idx_o.size();
std::vector<std::size_t> win_start;
std::vector<std::size_t> win_size;
for(std::size_t dim = 2; dim < n_dim; ++dim)
{
auto d_2 = dim - 2;
int start = static_cast<int>(idx_o[dim] * op.stride[d_2]) -
static_cast<int>(op.padding[d_2]);
int end = std::min(start + op.lengths[d_2], in_lens[dim]);
start = std::max(start, 0);
win_start.push_back(start);
win_size.push_back(end - start);
}
shape win_shape{output_shape.type(), win_size};
auto pool_size = win_shape.elements();
double acc = Op::template start<type>();
shape_for_each(win_shape, [&](auto idx_w) {
auto idx = idx_o;
std::transform(idx_w.begin(),
idx_w.end(),
win_start.begin(),
idx.begin() + 2,
[](auto ii, auto jj) { return ii + jj; });
if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
idx < in_lens)
{
acc = Op::apply(acc, input[in_s.index(idx)]);
}
});
output[i] = type(Op::final(acc, pool_size));
});
});
return args.back();
}
};
template struct cpu_pooling<avg_pool>;
template struct cpu_pooling<max_pool>;
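The cpu_pooling operators deleted above shared a small reduction protocol: start() seeds the accumulator, apply() folds in each window element, and final() normalizes by the window size. A self-contained sketch of one window reduced under that protocol (the *_sketch names and sample values are illustrative):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

struct max_pool_sketch
{
    template <class T>
    static T start() { return std::numeric_limits<T>::lowest(); }
    static double apply(double x, double y) { return std::max(x, y); }
    static double final(double x, std::size_t) { return x; }
};

struct avg_pool_sketch
{
    template <class T>
    static double start() { return 0.0; }
    static double apply(double x, double y) { return x + y; }
    static double final(double x, std::size_t n) { return n == 0 ? 0.0 : x / n; }
};

// One pooling window reduced under the start/apply/final protocol.
template <class Op>
double pool_window(const std::vector<double>& window)
{
    double acc = Op::template start<double>();
    for(double v : window)
        acc = Op::apply(acc, v);
    return Op::final(acc, window.size());
}

int main()
{
    std::vector<double> w = {1.0, 4.0, 2.0};
    std::cout << pool_window<max_pool_sketch>(w) << "\n"; // 4
    std::cout << pool_window<avg_pool_sketch>(w) << "\n"; // 2.33333
}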
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling> struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling>
{ {
std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; } std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; }
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{ {
auto algo = op.mode == "max" ? dnnl::algorithm::pooling_max : dnnl::algorithm::pooling_avg; auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto kdims = op.kdims(); auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims); std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end()); std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
...@@ -145,5 +34,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po ...@@ -145,5 +34,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
}; };
} // namespace cpu } // namespace cpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -11,7 +11,7 @@ if(NOT TARGET MIOpen) ...@@ -11,7 +11,7 @@ if(NOT TARGET MIOpen)
endif() endif()
include(Embed) include(Embed)
file(GLOB KERNEL_FILES file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES}) add_embed_library(migraphx_kernels ${KERNEL_FILES})
...@@ -93,7 +93,7 @@ add_library(migraphx_device ...@@ -93,7 +93,7 @@ add_library(migraphx_device
) )
add_library(compile_for_gpu INTERFACE) add_library(compile_for_gpu INTERFACE)
target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns) target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument) target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE) check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
if(HAS_HIP_LAMBDA_HOST_DEVICE) if(HAS_HIP_LAMBDA_HOST_DEVICE)
message(STATUS "Enable -fhip-lambda-host-device") message(STATUS "Enable -fhip-lambda-host-device")
...@@ -114,11 +114,13 @@ foreach(KERNEL_FILE ${KERNEL_FILES}) ...@@ -114,11 +114,13 @@ foreach(KERNEL_FILE ${KERNEL_FILES})
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include <migraphx/kernels/${KERNEL_BASE_FILE}.hpp>\n") file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include <migraphx/kernels/${KERNEL_BASE_FILE}.hpp>\n")
target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp) target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp)
endforeach() endforeach()
target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256)
target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>) target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>)
target_link_libraries(kernel_file_check compile_for_gpu) target_link_libraries(kernel_file_check compile_for_gpu)
rocm_clang_tidy_check(kernel_file_check) rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
add_library(migraphx_gpu add_library(migraphx_gpu
abs.cpp abs.cpp
analyze_streams.cpp analyze_streams.cpp
...@@ -129,10 +131,10 @@ add_library(migraphx_gpu ...@@ -129,10 +131,10 @@ add_library(migraphx_gpu
clip.cpp clip.cpp
code_object_op.cpp code_object_op.cpp
compile_ops.cpp compile_ops.cpp
compile_gen.cpp
compile_hip.cpp compile_hip.cpp
compile_hip_code_object.cpp compile_hip_code_object.cpp
compile_pointwise.cpp compiler.cpp
compile_roialign.cpp
concat.cpp concat.cpp
convert.cpp convert.cpp
convolution.cpp convolution.cpp
...@@ -157,6 +159,7 @@ add_library(migraphx_gpu ...@@ -157,6 +159,7 @@ add_library(migraphx_gpu
nonzero.cpp nonzero.cpp
pack_args.cpp pack_args.cpp
pack_int8_args.cpp pack_int8_args.cpp
prefuse_ops.cpp
pad.cpp pad.cpp
pooling.cpp pooling.cpp
quant_convolution.cpp quant_convolution.cpp
...@@ -170,6 +173,7 @@ add_library(migraphx_gpu ...@@ -170,6 +173,7 @@ add_library(migraphx_gpu
target.cpp target.cpp
topk.cpp topk.cpp
write_literals.cpp write_literals.cpp
${JIT_GPU_SRCS}
) )
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu) set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
...@@ -330,6 +334,12 @@ target_compile_definitions(migraphx_gpu PRIVATE ...@@ -330,6 +334,12 @@ target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}" "-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}"
"-DMIGRAPHX_USE_HIPRTC=0" "-DMIGRAPHX_USE_HIPRTC=0"
) )
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
endif() endif()
# Check miopen find mode api # Check miopen find mode api
......
...@@ -28,30 +28,30 @@ struct hip_stream_model ...@@ -28,30 +28,30 @@ struct hip_stream_model
bool is_wait(migraphx::instruction_ref ins) const { return ins->name() == "gpu::wait_event"; } bool is_wait(migraphx::instruction_ref ins) const { return ins->name() == "gpu::wait_event"; }
}; };
stream_model make_stream_model(const module& p) stream_model make_stream_model(const module& m)
{ {
hip_stream_model m; hip_stream_model hsm;
std::size_t stream = 0; std::size_t stream = 0;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
if(ins->name() == "gpu::set_stream") if(ins->name() == "gpu::set_stream")
{ {
auto v = ins->get_operator().to_value(); auto v = ins->get_operator().to_value();
stream = v["stream"].to<std::size_t>(); stream = v["stream"].to<std::size_t>();
m.max_stream = std::max(stream, m.max_stream); hsm.max_stream = std::max(stream, hsm.max_stream);
} }
if(ins->get_operator().is_context_free()) if(ins->get_operator().is_context_free())
continue; continue;
if(contains({"hip::hip_allocate_memory", "hip::hip_copy_literal", "@param"}, ins->name())) if(contains({"hip::hip_allocate_memory", "hip::hip_copy_literal", "@param"}, ins->name()))
continue; continue;
m.ins2stream[ins] = stream; hsm.ins2stream[ins] = stream;
} }
return m; return hsm;
} }
std::vector<stream_race> analyze_streams(const module& p) std::vector<stream_race> analyze_streams(const module& m)
{ {
return migraphx::analyze_streams(p, make_stream_model(p)); return migraphx::analyze_streams(m, make_stream_model(m));
} }
} // namespace gpu } // namespace gpu
......
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace gen {
static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
{
// If all inputs are half, then only use half2
if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
return s.type() == shape::half_type;
}))
return {2};
return {4, 2};
}
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
{
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(max_vec_size),
[&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis];
auto len = input.lens()[axis];
if(stride != 0 and stride != 1)
return 1;
if(len == 1 and input.elements() > sizes.front())
return sizes.front();
auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
if(it != sizes.end())
return *it;
return 1;
});
return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis};
}
std::string vectorize::str() const
{
return "vectorize<" + to_string(size) + ", " + to_string(axis) + ">()";
}
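vectorize::elements picks the widest width that every input tolerates on the chosen axis: the candidates are {2} when all inputs are half (only half2 is emitted) and {4, 2} otherwise; a stride other than 0 or 1 on the axis vetoes vectorization, and a broadcast length of 1 accepts the widest candidate. A hedged sketch of the per-input decision (input_vec_size is an illustrative name):

#include <algorithm>
#include <cstddef>
#include <vector>

// Width a single input allows on the axis, given candidate sizes {4, 2}.
std::size_t input_vec_size(std::size_t stride,
                           std::size_t len,
                           std::size_t elements,
                           const std::vector<std::size_t>& sizes)
{
    if(stride != 0 and stride != 1)
        return 1; // not contiguous along the axis: no vectorization
    if(len == 1 and elements > sizes.front())
        return sizes.front(); // broadcast axis: any width is fine
    auto it = std::find_if(
        sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
    return it == sizes.end() ? 1 : *it;
}

// The kernel-wide width is the minimum over all inputs: axis lengths of
// 8 and 6 give min(4, 2) == 2, so the kernel vectorizes by 2.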
preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
{
const std::size_t max_lds_bytes = 4096;
std::vector<bool> result;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(result),
[&](const shape& input) { return input.strides()[axis] == 0; });
auto bytes = std::inner_product(inputs.begin(),
inputs.end(),
result.begin(),
std::size_t{0},
std::plus<>{},
[](const shape& s, bool b) -> std::size_t {
if(b)
return s.bytes();
return 0;
});
if(bytes < max_lds_bytes)
return {result};
// TODO: Try to partially preload items
std::fill(result.begin(), result.end(), false);
return {result};
}
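preload::broadcasts only commits inputs to LDS when the broadcast operands fit a 4 KB budget; if the sum of their byte sizes crosses the limit it backs off entirely (the TODO above notes partial preloading as future work). A standalone sketch of the budget check, with a stand-in struct instead of migraphx::shape:

#include <algorithm>
#include <cstddef>
#include <vector>

struct shape_info
{
    std::size_t bytes;       // total byte size of the tensor
    std::size_t axis_stride; // stride along the preload axis
};

std::vector<bool> choose_preload(const std::vector<shape_info>& inputs)
{
    const std::size_t max_lds_bytes = 4096;
    std::vector<bool> result;
    std::size_t total = 0;
    for(const auto& s : inputs)
    {
        bool broadcast = s.axis_stride == 0; // broadcast along the axis
        result.push_back(broadcast);
        if(broadcast)
            total += s.bytes;
    }
    if(total >= max_lds_bytes) // over budget: preload nothing for now
        std::fill(result.begin(), result.end(), false);
    return result;
}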
std::string preload::str() const
{
std::vector<std::string> bool_strs;
std::transform(args.begin(), std::prev(args.end()), std::back_inserter(bool_strs), [](bool b) {
if(b)
return "true";
return "false";
});
return "auto_preload<false, " + join_strings(bool_strs, ", ") + ">(idx)";
}
bool preload::is_preloading() const
{
return std::accumulate(args.begin(), args.end(), false, std::logical_or<>{});
}
std::size_t find_fast_axis(const std::vector<shape>& inputs)
{
auto permutation = find_permutation(inputs);
auto it = std::max_element(permutation.begin(), permutation.end());
return it - permutation.begin();
}
std::string make_transformer_args(std::vector<std::string> transformers)
{
return join_strings(std::move(transformers), ", ");
}
} // namespace gen
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -21,6 +21,8 @@ namespace gpu { ...@@ -21,6 +21,8 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#if MIGRAPHX_USE_HIPRTC #if MIGRAPHX_USE_HIPRTC
...@@ -178,6 +180,19 @@ bool is_hip_clang_compiler() ...@@ -178,6 +180,19 @@ bool is_hip_clang_compiler()
return result; return result;
} }
bool has_compiler_launcher()
{
static const auto result = fs::exists(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER));
return result;
}
src_compiler assemble(src_compiler compiler)
{
compiler.out_ext = ".S";
compiler.flags = replace_string(compiler.flags, " -c", " -S");
return compiler;
}
std::vector<std::vector<char>> std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch) compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{ {
...@@ -210,6 +225,10 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -210,6 +225,10 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
src_compiler compiler; src_compiler compiler;
compiler.flags = params; compiler.flags = params;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER); compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
#endif
if(is_hcc_compiler()) if(is_hcc_compiler())
compiler.process = [&](const fs::path& obj_path) -> fs::path { compiler.process = [&](const fs::path& obj_path) -> fs::path {
...@@ -228,6 +247,22 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -228,6 +247,22 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
MIGRAPHX_THROW("Missing hsaco"); MIGRAPHX_THROW("Missing hsaco");
}; };
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
for(const auto& src : srcs)
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
}
}
if(enabled(MIGRAPHX_GPU_DUMP_ASM{}))
{
std::cout << assemble(compiler).compile(srcs).data() << std::endl;
}
return {compiler.compile(srcs)}; return {compiler.compile(srcs)};
} }
...@@ -238,13 +273,6 @@ std::string enum_params(std::size_t count, std::string param) ...@@ -238,13 +273,6 @@ std::string enum_params(std::size_t count, std::string param)
return join_strings(items, ","); return join_strings(items, ",");
} }
std::size_t compute_global(std::size_t n, std::size_t local)
{
std::size_t groups = (n + local - 1) / local;
std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
return nglobal;
}
#endif // MIGRAPHX_USE_HIPRTC #endif // MIGRAPHX_USE_HIPRTC
} // namespace gpu } // namespace gpu
......
...@@ -93,8 +93,47 @@ const std::vector<std::string>& compiler_warnings() ...@@ -93,8 +93,47 @@ const std::vector<std::string>& compiler_warnings()
return warnings; return warnings;
} }
void hip_compile_options::set_launch_params(
const value& v,
const std::function<std::size_t(std::size_t local)>& compute_global,
std::size_t default_local)
{
local = v.get("local", default_local);
if(v.contains("global"))
global = v.at("global").to<std::size_t>();
else
global = compute_global(local);
}
std::function<std::size_t(std::size_t local)>
compute_global_for(context& ctx, std::size_t n, std::size_t over)
{
assert(over > 0);
std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu();
return [n, over, max_global](std::size_t local) {
std::size_t groups = (n + local - 1) / local;
std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local;
return nglobal;
};
}
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{
size_t block_size = 128;
while(block_size <= max_block_size and block_size <= n)
block_size *= 2;
return block_size / 2;
}
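compute_block_size doubles from 128 until the next step would pass either bound and then steps back once, while compute_global_for caps the grid at `over` waves of the device's maximum occupancy. A small worked example; the CU count and workitems-per-CU below are made-up device numbers:

#include <algorithm>
#include <cstddef>
#include <iostream>

std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{
    std::size_t block_size = 128;
    while(block_size <= max_block_size and block_size <= n)
        block_size *= 2;
    return block_size / 2;
}

int main()
{
    // 128 -> 256 -> 512 -> 1024 overshoots n = 1000, so 1024 / 2 = 512.
    std::cout << compute_block_size(1000, 1024) << "\n"; // 512

    // compute_global_for, unrolled with assumed device numbers:
    // 120 CUs * 2048 workitems per CU = 245760 resident workitems.
    std::size_t n = 1 << 20, local = 512, over = 1;
    std::size_t max_global = 120 * 2048;
    std::size_t groups     = (n + local - 1) / local; // 2048 blocks wanted
    std::size_t max_blocks = max_global / local;      // 480 blocks resident
    std::size_t nglobal    = std::min(max_blocks * over, groups) * local;
    std::cout << nglobal << "\n"; // 245760: the grid is capped at one wave
}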
operation compile_hip_code_object(const std::string& content, hip_compile_options options) operation compile_hip_code_object(const std::string& content, hip_compile_options options)
{ {
assert(options.global > 0);
assert(options.local > 0);
assert(not options.inputs.empty());
assert(options.inputs.size() == options.virtual_inputs.size() or
options.virtual_inputs.empty());
std::vector<src_file> srcs; std::vector<src_file> srcs;
std::transform(migraphx_kernels().begin(), std::transform(migraphx_kernels().begin(),
migraphx_kernels().end(), migraphx_kernels().end(),
......
...@@ -6,12 +6,14 @@ ...@@ -6,12 +6,14 @@
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/register_op.hpp> #include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compile_pointwise.hpp> #include <migraphx/gpu/compiler.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
struct precompile_op struct precompile_op
{ {
operation op = op::identity{}; operation op = op::identity{};
...@@ -38,41 +40,22 @@ struct precompile_op ...@@ -38,41 +40,22 @@ struct precompile_op
MIGRAPHX_REGISTER_OP(precompile_op); MIGRAPHX_REGISTER_OP(precompile_op);
struct pointwise_compiler struct compiled_result
{ {
std::string name() const { return "pointwise"; } compiler_replace replace;
instruction_ref ins;
operation apply(context& ctx, instruction_ref ins, const operation&) const
{
assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front();
return compile_pointwise(ctx, to_shapes(ins->inputs()), *pm);
}
}; };
using compiler_function = std::function<operation(context&, instruction_ref, operation)>; template <class F>
void par_compile(std::size_t n, F f)
template <class T>
compiler_function make_compiler_function(T x)
{ {
return {[=](auto&&... xs) { return x.apply(xs...); }}; if(n == 0)
return;
par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
} }
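par_compile's grain size is n / MIGRAPHX_GPU_COMPILE_PARALLEL, with the divisor defaulting to n when the variable is unset (grain 1, maximal parallelism), so the environment variable effectively caps the number of concurrent compile workers. A rough stand-in for that chunking, assuming nothing about par_for's real implementation:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Split [0, n) into chunks of `grain` and run each chunk on its own
// thread; grain = n / p yields roughly p workers.
void par_for_sketch(std::size_t n, std::size_t grain, const std::function<void(std::size_t)>& f)
{
    grain = std::max<std::size_t>(grain, 1);
    std::vector<std::thread> workers;
    for(std::size_t start = 0; start < n; start += grain)
    {
        std::size_t stop = std::min(start + grain, n);
        workers.emplace_back([=, &f] {
            for(std::size_t i = start; i < stop; i++)
                f(i);
        });
    }
    for(auto& w : workers)
        w.join();
}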
template <class... Ts>
std::unordered_map<std::string, compiler_function> make_compilers(Ts... xs)
{
return {{xs.name(), make_compiler_function(xs)}...};
}
struct compiled_result
{
operation op;
instruction_ref ins;
};
void compile_ops::apply(module& m) const void compile_ops::apply(module& m) const
{ {
auto compilers = make_compilers(pointwise_compiler{});
std::vector<std::function<compiled_result()>> compiles; std::vector<std::function<compiled_result()>> compiles;
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
...@@ -80,15 +63,15 @@ void compile_ops::apply(module& m) const ...@@ -80,15 +63,15 @@ void compile_ops::apply(module& m) const
if(ins->name() != "gpu::precompile_op") if(ins->name() != "gpu::precompile_op")
continue; continue;
operation preop = any_cast<precompile_op>(ins->get_operator()).op; operation preop = any_cast<precompile_op>(ins->get_operator()).op;
assert(contains(compilers, preop.name())); compiles.emplace_back([=]() -> compiled_result {
auto c = compilers[preop.name()]; return {compile(*ctx, ins, preop), ins};
compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; }); });
} }
std::vector<compiled_result> results(compiles.size()); std::vector<compiled_result> results(compiles.size());
par_for(compiles.size(), 1, [&](auto i) { results[i] = compiles[i](); }); par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
for(const auto& cr : results) for(const auto& cr : results)
{ {
m.replace_instruction(cr.ins, cr.op, cr.ins->inputs()); cr.replace(m, cr.ins);
} }
} }
......
#include <migraphx/gpu/compile_pointwise.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void kernel(${params})
{
pointwise(${lambda}, ${args});
}
}
} // namespace migraphx
int main() {}
)__migraphx__";
operation compile_pointwise(context&,
const std::vector<shape>& inputs,
const std::string& lambda,
const std::string& preamble)
{
hip_compile_options options;
options.global = compute_global(inputs.front().elements());
options.local = 1024;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"lambda", lambda},
{"preamble", preamble}});
return compile_hip_code_object(src, options);
}
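To see what the ${params} and ${args} holes become: enum_params joins count copies of the parameter text, suffixed by an index, with commas. A tiny re-implementation for illustration (the real helper lives in compile_hip.cpp; the exact index-suffix scheme is assumed here):

#include <cstddef>
#include <iostream>
#include <string>

std::string enum_params_sketch(std::size_t count, const std::string& param)
{
    std::string out;
    for(std::size_t i = 0; i < count; i++)
        out += (i == 0 ? "" : ",") + param + std::to_string(i);
    return out;
}

int main()
{
    // Two inputs plus the output allocation give three kernel parameters:
    // void * private_p0,void * private_p1,void * private_p2
    std::cout << enum_params_sketch(3, "void * private_p") << "\n";
}

With those substituted, the generated __global__ kernel takes the erased pointers and forwards them to pointwise along with the lifted generated function.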
operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, module m)
{
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
auto name =
g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx