Commit 11e155c2 authored by Paul

Merge

parents 8a9c5bce aa7ff911
@@ -12,9 +12,9 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_pooling::apply(module& prog) const
void rewrite_pooling::apply(module& m) const
{
for(auto ins : iterator_for(prog))
for(auto ins : iterator_for(m))
{
if(ins->name() != "pooling")
continue;
@@ -33,26 +33,25 @@ void rewrite_pooling::apply(module& prog) const
continue;
std::int64_t n = s.lens()[0];
std::int64_t c = s.lens()[1];
auto reshape = prog.insert_instruction(
auto reshape = m.insert_instruction(
ins, make_op("reshape", {{"dims", {n * c, -1}}}), ins->inputs().front());
instruction_ref pooling{};
// average pooling
if(op.mode == "average")
if(op.mode == op::pooling_mode::average)
{
pooling =
prog.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
pooling = m.insert_instruction(ins, make_op("reduce_mean", {{"axes", {1}}}), reshape);
}
// max pooling
else
{
pooling = prog.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
pooling = m.insert_instruction(ins, make_op("reduce_max", {{"axes", {1}}}), reshape);
}
std::vector<int64_t> rsp_lens(lens.size(), 1);
rsp_lens[0] = n;
rsp_lens[1] = c;
prog.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
m.replace_instruction(ins, make_op("reshape", {{"dims", rsp_lens}}), pooling);
}
}
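// A worked example of the rewrite above (hypothetical shapes, assuming the pooling
// window covers the full spatial extent, which is the case this pass matches):
//   input x:  {n=2, c=3, h=4, w=4}
//   reshape   -> {6, 16}    ({n * c, -1} flattens each (n, c) plane)
//   reduce_mean(axes={1}) or reduce_max(axes={1}) -> {6, 1}
//   reshape   -> {2, 3, 1, 1}    (rsp_lens = {n, c, 1, ..., 1})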
......
@@ -30,27 +30,27 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void rewrite_rnn::apply(module& prog) const
void rewrite_rnn::apply(module& m) const
{
for(auto ins : iterator_for(prog))
for(auto ins : iterator_for(m))
{
if(ins->name() == "rnn")
{
apply_vanilla_rnn(prog, ins);
apply_vanilla_rnn(m, ins);
}
else if(ins->name() == "gru")
{
apply_gru(prog, ins);
apply_gru(m, ins);
}
else if(ins->name() == "lstm")
{
apply_lstm(prog, ins);
apply_lstm(m, ins);
}
}
}
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
{
assert(ins->name() == "rnn");
// could be 3 to 6 inputs, but the parse_rnn function will
@@ -71,37 +71,37 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
op::rnn_direction dirct = rnn_op.direction;
// process sequence length
instruction_ref seq_lens = prog.end();
instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
{
seq_lens = args[4];
}
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
instruction_ref last_output{};
if(dirct == op::rnn_direction::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(
auto w_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
auto w_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(
auto r_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
auto r_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(
bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
bias_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
}
@@ -111,57 +111,56 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(
ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ih_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward =
vanilla_rnn_cell(true,
prog,
m,
ins,
{args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
actv_funcs.at(0));
if(variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret_reverse =
vanilla_rnn_cell(false,
prog,
m,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
actv_funcs.at(1));
auto concat_output = prog.insert_instruction(
auto concat_output = m.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
last_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
// The following logic ensures that the last instruction rewritten from the
// rnn operator is a concat instruction
// sequence len is 1
if(ret_forward[0] == prog.end())
if(ret_forward[0] == m.end())
{
prog.replace_instruction(
m.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] = prog.insert_instruction(
ret_forward[0] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[0] = prog.insert_instruction(
ret_reverse[0] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(
m.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
}
}
@@ -175,7 +174,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
auto r = args[2];
// process bias and initial hidden state
instruction_ref bias = prog.end();
instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
@@ -189,43 +188,42 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
ih = m.add_literal(migraphx::literal{ih_shape, data});
}
if(!is_forward and variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret = vanilla_rnn_cell(
is_forward, prog, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
is_forward, m, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
// the following logic ensures that the last instruction is a
// concat instruction
// sequence len is 1
if(ret[0] == prog.end())
if(ret[0] == m.end())
{
prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
}
}
// in case all sequences are of the same length and shorter than the
// max sequence length, we need to pad 0's at the end of the output hidden states
ins = pad_hidden_states(prog, args[0], seq_lens, ins);
replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
ins = pad_hidden_states(m, args[0], seq_lens, ins);
replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}
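// For reference, vanilla_rnn_cell below computes the standard ONNX RNN update,
// where f is the activation function and (.) is element-wise multiplication:
//   Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)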
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
operation& actv_func) const
@@ -240,60 +238,60 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tran_sw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// squeeze and transpose r
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tran_sr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
auto sih_lens = sih->get_shape().lens();
// bias
instruction_ref bb{};
if(bias != prog.end())
if(bias != m.end())
{
long hs = static_cast<long>(r->get_shape().lens()[2]);
auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = prog.insert_instruction(
auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias);
auto rb = prog.insert_instruction(
auto rb = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
bb = prog.insert_instruction(
auto wrb = m.insert_instruction(ins, make_op("add"), wb, rb);
bb = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
}
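// Note: the 2*hidden_size ONNX RNN bias packs [Wbi, Rbi]; since both terms are
// simply added to the pre-activation, wb and rb are summed once here and
// broadcast to the {batch_size, hidden_size} shape of sih.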
instruction_ref hidden_out = prog.end();
instruction_ref hidden_out = m.end();
instruction_ref last_out{};
last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
long seq_len = get_seq_len(prog, seq, seq_lens);
last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
long seq_len = get_seq_len(m, seq, seq_lens);
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(
auto xt = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq);
auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_wi = prog.insert_instruction(ins, make_op("dot"), xt, tran_sw);
auto ht_ri = prog.insert_instruction(ins, make_op("dot"), sih, tran_sr);
if(bias != prog.end())
auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_wi = m.insert_instruction(ins, make_op("dot"), xt, tran_sw);
auto ht_ri = m.insert_instruction(ins, make_op("dot"), sih, tran_sr);
if(bias != m.end())
{
xt_wi = prog.insert_instruction(ins, make_op("add"), xt_wi, bb);
xt_wi = m.insert_instruction(ins, make_op("add"), xt_wi, bb);
}
auto xt_ht = prog.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);
auto xt_ht = m.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);
// apply activation function
auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
auto ht = m.insert_instruction(ins, actv_func, xt_ht);
sih = ht;
// add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions)
last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
// concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have
@@ -304,14 +302,14 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
{
hidden_out = (seq_index == 0)
? last_out
: prog.insert_instruction(
: m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out);
}
else
{
hidden_out = (seq_index == seq_len - 1)
? last_out
: prog.insert_instruction(
: m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out);
}
}
@@ -358,7 +356,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
}
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
{
assert(ins->name() == "gru");
const auto actv_funcs = gru_actv_funcs(ins);
@@ -379,37 +377,37 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
op::rnn_direction dirct = gru_op.direction;
// process sequence length
instruction_ref seq_lens = prog.end();
instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
{
seq_lens = args[4];
}
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
instruction_ref last_output{};
if(dirct == op::rnn_direction::bidirectional)
{
// w weight matrix
auto w_forward = prog.insert_instruction(
auto w_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
auto w_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// r weight matrix
auto r_forward = prog.insert_instruction(
auto r_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
auto r_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(
bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
bias_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
}
@@ -418,20 +416,20 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(
ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ih_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
}
auto ret_forward =
gru_cell(true,
prog,
m,
ins,
{args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
gru_op.linear_before_reset,
@@ -440,38 +438,37 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
if(variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret_reverse =
gru_cell(false,
prog,
m,
ins,
{args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
gru_op.linear_before_reset,
actv_funcs.at(2),
actv_funcs.at(3));
auto concat_output = prog.insert_instruction(
auto concat_output = m.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
last_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
// The following logic ensures that the last instruction rewritten
// from the gru operator is a concat
if(ret_forward[0] == prog.end())
if(ret_forward[0] == m.end())
{
prog.replace_instruction(
m.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
}
else
{
ret_forward[0] = prog.insert_instruction(
ret_forward[0] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[0] = prog.insert_instruction(
ret_reverse[0] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(
m.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
}
}
@@ -483,7 +480,7 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
auto r = args[2];
// bias
instruction_ref bias = prog.end();
instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
@@ -497,47 +494,46 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
}
else
{
ih = prog.add_literal(migraphx::literal{ih_shape, data});
ih = m.add_literal(migraphx::literal{ih_shape, data});
}
if(!is_forward and variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret = gru_cell(is_forward,
prog,
m,
ins,
{args[0], w, r, bias, seq_lens, ih},
gru_op.linear_before_reset,
actv_funcs.at(0),
actv_funcs.at(1));
last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
if(ret[0] == prog.end())
if(ret[0] == m.end())
{
prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
prog.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
}
}
// in case all sequences are of the same length and shorter than the
// max sequence length, we need to pad 0's at the end of the output hidden states
ins = pad_hidden_states(prog, args[0], seq_lens, ins);
replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
ins = pad_hidden_states(m, args[0], seq_lens, ins);
replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}
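// For reference, gru_cell below follows the ONNX GRU equations, where (.) is
// element-wise multiplication and f/g are the two activation functions:
//   zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
//   rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
//   ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)   // linear_before_reset == 0
//   ht = g(Xt*(Wh^T) + rt (.) (Ht-1*(Rh^T) + Rbh) + Wbh)   // linear_before_reset != 0
//   Ht = (1 - zt) (.) ht + zt (.) Ht-1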
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
@@ -552,7 +548,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto seq_lens = inputs.at(4);
auto ih = inputs.at(5);
instruction_ref hidden_states = prog.end();
instruction_ref hidden_states = m.end();
instruction_ref last_output{};
migraphx::shape seq_shape = seq->get_shape();
migraphx::shape r_shape = r->get_shape();
@@ -560,127 +556,127 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
std::vector<float> data(ss.elements(), 1.0f);
auto l1 = prog.add_literal(migraphx::literal{ss, data});
auto l1 = m.add_literal(migraphx::literal{ss, data});
// squeeze the w matrix to 2-dim and transpose it
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// slice r into two parts, zr and h
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto rzr = prog.insert_instruction(
auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto rzr = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
auto trzr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
auto rh = prog.insert_instruction(
auto rh = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
auto trh = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
auto trh = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
// initial states
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
size_t bs = ih->get_shape().lens()[1];
// bias
instruction_ref bwb{};
instruction_ref brb_zr{};
instruction_ref brb_h{};
if(bias != prog.end())
if(bias != m.end())
{
auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = prog.insert_instruction(
auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto wb = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
bwb = prog.insert_instruction(
bwb = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
wb);
auto rb_zr = prog.insert_instruction(
auto rb_zr = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}),
sbias);
auto rb_h = prog.insert_instruction(
auto rb_h = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}),
sbias);
brb_zr = prog.insert_instruction(
brb_zr = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
rb_zr);
brb_h = prog.insert_instruction(
brb_h = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
rb_h);
}
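// Note: the 6*hs ONNX GRU bias is laid out as [Wbz Wbr Wbh Rbz Rbr Rbh]; the
// slices above take [0, 3*hs) for the W biases, [3*hs, 5*hs) for Rbz/Rbr, and
// [5*hs, 6*hs) for Rbh, each broadcast across the batch dimension.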
long seq_len = get_seq_len(prog, seq, seq_lens);
long seq_len = get_seq_len(m, seq, seq_lens);
for(long i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(
auto xt = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq);
auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_w = prog.insert_instruction(ins, make_op("dot"), xt, tw);
auto ih1_rzr = prog.insert_instruction(ins, make_op("dot"), sih, trzr);
if(bias != prog.end())
auto xt_w = m.insert_instruction(ins, make_op("dot"), xt, tw);
auto ih1_rzr = m.insert_instruction(ins, make_op("dot"), sih, trzr);
if(bias != m.end())
{
xt_w = prog.insert_instruction(ins, make_op("add"), xt_w, bwb);
ih1_rzr = prog.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
xt_w = m.insert_instruction(ins, make_op("add"), xt_w, bwb);
ih1_rzr = m.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
}
auto xw_z = prog.insert_instruction(
auto xw_z = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w);
auto xw_r = prog.insert_instruction(
auto xw_r = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w);
auto xw_h = prog.insert_instruction(
auto xw_h = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w);
auto hr_z = prog.insert_instruction(
auto hr_z = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr);
auto hr_r = prog.insert_instruction(
auto hr_r = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr);
auto xw_hr_z = prog.insert_instruction(ins, make_op("add"), xw_z, hr_z);
auto zt = prog.insert_instruction(ins, actv_func1, xw_hr_z);
auto xw_hr_z = m.insert_instruction(ins, make_op("add"), xw_z, hr_z);
auto zt = m.insert_instruction(ins, actv_func1, xw_hr_z);
auto xw_hr_r = prog.insert_instruction(ins, make_op("add"), xw_r, hr_r);
auto rt = prog.insert_instruction(ins, actv_func1, xw_hr_r);
auto xw_hr_r = m.insert_instruction(ins, make_op("add"), xw_r, hr_r);
auto rt = m.insert_instruction(ins, actv_func1, xw_hr_r);
instruction_ref hr_h{};
if(linear_before_reset == 0)
{
// equation ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_ht1 = prog.insert_instruction(ins, make_op("mul"), rt, sih);
hr_h = prog.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
if(bias != prog.end())
auto rt_ht1 = m.insert_instruction(ins, make_op("mul"), rt, sih);
hr_h = m.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
if(bias != m.end())
{
hr_h = prog.insert_instruction(ins, make_op("add"), hr_h, brb_h);
hr_h = m.insert_instruction(ins, make_op("add"), hr_h, brb_h);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto ht1_rh = prog.insert_instruction(ins, make_op("dot"), sih, trh);
if(bias != prog.end())
auto ht1_rh = m.insert_instruction(ins, make_op("dot"), sih, trh);
if(bias != m.end())
{
ht1_rh = prog.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
ht1_rh = m.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
}
hr_h = prog.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
hr_h = m.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
}
auto xw_hr_h = prog.insert_instruction(ins, make_op("add"), xw_h, hr_h);
auto ht = prog.insert_instruction(ins, actv_func2, xw_hr_h);
auto xw_hr_h = m.insert_instruction(ins, make_op("add"), xw_h, hr_h);
auto ht = m.insert_instruction(ins, actv_func2, xw_hr_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, make_op("sub"), l1, zt);
auto one_minus_zt_ht = prog.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
auto zt_ht1 = prog.insert_instruction(ins, make_op("mul"), zt, sih);
sih = prog.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
last_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
auto one_minus_zt = m.insert_instruction(ins, make_op("sub"), l1, zt);
auto one_minus_zt_ht = m.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
auto zt_ht1 = m.insert_instruction(ins, make_op("mul"), zt, sih);
sih = m.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
last_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
if(i < seq_len - 1)
{
@@ -689,7 +685,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
hidden_states =
(seq_index == 0)
? last_output
: prog.insert_instruction(
: m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output);
}
else
@@ -697,7 +693,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
hidden_states =
(seq_index == seq_len - 1)
? last_output
: prog.insert_instruction(
: m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states);
}
}
@@ -748,7 +744,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
// for lstm operators
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
{
assert(ins->name() == "lstm");
auto args = ins->inputs();
@@ -767,13 +763,13 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
op::rnn_direction dirct = lstm_op.direction;
// process sequence length
instruction_ref seq_lens = prog.end();
instruction_ref seq_lens = m.end();
if((args.size() >= 5) && args[4]->name() != "undefined")
{
seq_lens = args[4];
}
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
instruction_ref last_hs_output{};
instruction_ref last_cell_output{};
@@ -783,25 +779,25 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
{
// input weight matrix
auto w_forward = prog.insert_instruction(
auto w_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
auto w_reverse = prog.insert_instruction(
auto w_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
// hidden state weight matrix
auto r_forward = prog.insert_instruction(
auto r_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
auto r_reverse = prog.insert_instruction(
auto r_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
// process bias
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
instruction_ref bias_forward = m.end();
instruction_ref bias_reverse = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias_forward = prog.insert_instruction(
bias_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
bias_reverse = prog.insert_instruction(
bias_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
}
@@ -810,15 +806,15 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref ih_reverse{};
if(args.size() >= 6 && args[5]->name() != "undefined")
{
ih_forward = prog.insert_instruction(
ih_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
ih_reverse = prog.insert_instruction(
ih_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
}
else
{
ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ih_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
ih_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process initial cell value
@@ -826,30 +822,30 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref ic_reverse{};
if(args.size() >= 7 && args[6]->name() != "undefined")
{
ic_forward = prog.insert_instruction(
ic_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
ic_reverse = prog.insert_instruction(
ic_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]);
}
else
{
ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ic_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
ic_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process weight of the peephole
instruction_ref pph_forward = prog.end();
instruction_ref pph_reverse = prog.end();
instruction_ref pph_forward = m.end();
instruction_ref pph_reverse = m.end();
if(args.size() == 8 && args[7]->name() != "undefined")
{
pph_forward = prog.insert_instruction(
pph_forward = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
pph_reverse = prog.insert_instruction(
pph_reverse = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]);
}
auto ret_forward = lstm_cell(true,
prog,
m,
ins,
{args[0],
w_forward,
@@ -865,11 +861,11 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
if(variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret_reverse = lstm_cell(false,
prog,
m,
ins,
{args[0],
w_reverse,
@@ -883,36 +879,36 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
actv_funcs.at(4),
actv_funcs.at(5));
auto concat_hs_output = prog.insert_instruction(
auto concat_hs_output = m.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
auto concat_cell_output = prog.insert_instruction(
auto concat_cell_output = m.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
last_hs_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
last_cell_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);
m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);
// the following logic ensures that the last instruction is a concat
if(ret_forward[0] == prog.end())
if(ret_forward[0] == m.end())
{
cell_outputs = concat_cell_output;
}
else
{
ret_forward[1] = prog.insert_instruction(
ret_forward[1] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
ret_reverse[1] = prog.insert_instruction(
ret_reverse[1] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
ret_forward[3] = prog.insert_instruction(
ret_forward[3] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]);
ret_reverse[3] = prog.insert_instruction(
ret_reverse[3] = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]);
cell_outputs = prog.insert_instruction(
cell_outputs = m.insert_instruction(
ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
}
hidden_state = prog.replace_instruction(
hidden_state = m.replace_instruction(
ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]});
}
else
@@ -923,7 +919,7 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
auto r = args[2];
// bias
instruction_ref bias = prog.end();
instruction_ref bias = m.end();
if(args.size() >= 4 && args[3]->name() != "undefined")
{
bias = args[3];
@@ -937,7 +933,7 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
}
else
{
ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ih = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// initial cell value
@@ -948,11 +944,11 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
}
else
{
ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
ic = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
}
// process weight of the peephole
instruction_ref pph = prog.end();
instruction_ref pph = m.end();
if(args.size() == 8 && args[7]->name() != "undefined")
{
pph = args[7];
@@ -960,54 +956,53 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
if(!is_forward and variable_seq_len)
{
args[0] = prog.insert_instruction(
ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
args[0] =
m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
}
auto ret = lstm_cell(is_forward,
prog,
m,
ins,
{args[0], w, r, bias, seq_lens, ih, ic, pph},
actv_funcs.at(0),
actv_funcs.at(1),
actv_funcs.at(2));
last_hs_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
last_cell_output =
prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);
last_hs_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
last_cell_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);
if(ret[0] == prog.end())
if(ret[0] == m.end())
{
cell_outputs = ret[3];
hidden_state = prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
hidden_state = m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
}
else
{
auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
cell_outputs = prog.insert_instruction(
cell_outputs = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
hidden_state = prog.replace_instruction(
hidden_state = m.replace_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
}
}
// in case all sequences are of the same length and shorter than the
// max sequence length, we need to pad 0's at the end of the output hidden states
hidden_state = pad_hidden_states(prog, args[0], seq_lens, hidden_state);
hidden_state = pad_hidden_states(m, args[0], seq_lens, hidden_state);
// replace last hidden states with corresponding instructions
ins = replace_last_hs_output(prog, hidden_state, seq_lens, last_hs_output, dirct);
ins = replace_last_hs_output(m, hidden_state, seq_lens, last_hs_output, dirct);
// replace last cell outputs with corresponding instructions
replace_last_cell_output(prog, ins, seq_lens, cell_outputs, last_cell_output, dirct);
replace_last_cell_output(m, ins, seq_lens, cell_outputs, last_cell_output, dirct);
}
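// For reference, lstm_cell below follows the ONNX LSTM equations, where (.) is
// element-wise multiplication, Pi/Po/Pf are the optional peephole weights, and
// f/g/h are the three activation functions:
//   it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
//   ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
//   ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
//   Ct = ft (.) Ct-1 + it (.) ct
//   ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
//   Ht = ot (.) h(Ct)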
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
module& prog,
module& m,
instruction_ref ins,
std::vector<instruction_ref> inputs,
const operation& actv_func1,
@@ -1025,8 +1020,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ic = inputs.at(6);
auto pph = inputs.at(7);
instruction_ref hidden_states = prog.end();
instruction_ref cell_outputs = prog.end();
instruction_ref hidden_states = m.end();
instruction_ref cell_outputs = m.end();
instruction_ref last_hs_output{};
instruction_ref last_cell_output{};
@@ -1037,35 +1032,35 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
std::vector<int64_t> perm{1, 0};
// w matrix, squeeze and transpose
auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
auto tsw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
// r matrix, squeeze and transpose
auto sr = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
auto sr = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
auto tsr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
// initial hidden state
auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
// initial cell state
auto sic = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
auto sic = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
auto ic_lens = sic->get_shape().lens();
// bias
instruction_ref wrb{};
if(bias != prog.end())
if(bias != m.end())
{
auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto ub_wb = prog.insert_instruction(
auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
auto ub_wb = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias);
auto ub_rb = prog.insert_instruction(
auto ub_rb = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}),
sbias);
auto ub_wrb = prog.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);
auto ub_wrb = m.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);
wrb = prog.insert_instruction(
wrb = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
ub_wrb);
@@ -1075,92 +1070,91 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref pphi_brcst{};
instruction_ref ppho_brcst{};
instruction_ref pphf_brcst{};
if(pph != prog.end())
if(pph != m.end())
{
auto spph = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
auto pphi = prog.insert_instruction(
auto spph = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
auto pphi = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
pphi_brcst = prog.insert_instruction(
pphi_brcst = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);
auto ppho = prog.insert_instruction(
auto ppho = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
ppho_brcst = prog.insert_instruction(
ppho_brcst = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);
auto pphf = prog.insert_instruction(
auto pphf = m.insert_instruction(
ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
pphf_brcst = prog.insert_instruction(
pphf_brcst = m.insert_instruction(
ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
}
long seq_len = get_seq_len(prog, seq, seq_lens);
long seq_len = get_seq_len(m, seq, seq_lens);
for(long i = 0; i < seq_len; ++i)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(
auto xt = m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
seq);
auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
xt = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
xt = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
auto xt_tsw = prog.insert_instruction(ins, make_op("dot"), xt, tsw);
auto sih_tsr = prog.insert_instruction(ins, make_op("dot"), sih, tsr);
auto xt_sih = prog.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
if(bias != prog.end())
auto xt_tsw = m.insert_instruction(ins, make_op("dot"), xt, tsw);
auto sih_tsr = m.insert_instruction(ins, make_op("dot"), sih, tsr);
auto xt_sih = m.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
if(bias != m.end())
{
xt_sih = prog.insert_instruction(ins, make_op("add"), xt_sih, wrb);
xt_sih = m.insert_instruction(ins, make_op("add"), xt_sih, wrb);
}
auto it_before_actv = prog.insert_instruction(
auto it_before_actv = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih);
auto ot_before_actv = prog.insert_instruction(
auto ot_before_actv = m.insert_instruction(
ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih);
auto ft_before_actv = prog.insert_instruction(
auto ft_before_actv = m.insert_instruction(
ins,
make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}),
xt_sih);
auto ct_before_actv = prog.insert_instruction(
auto ct_before_actv = m.insert_instruction(
ins,
make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}),
xt_sih);
if(pph != prog.end())
if(pph != m.end())
{
auto pphi_ct = prog.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);
auto pphi_ct = m.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
it_before_actv = m.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);
auto pphf_ct = prog.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
auto pphf_ct = m.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
ft_before_actv = m.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
}
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
auto it = m.insert_instruction(ins, actv_func1, it_before_actv);
auto ft = m.insert_instruction(ins, actv_func1, ft_before_actv);
auto ct = m.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct
auto ft_cell = prog.insert_instruction(ins, make_op("mul"), ft, sic);
auto it_ct = prog.insert_instruction(ins, make_op("mul"), it, ct);
auto cellt = prog.insert_instruction(ins, make_op("add"), ft_cell, it_ct);
auto ft_cell = m.insert_instruction(ins, make_op("mul"), ft, sic);
auto it_ct = m.insert_instruction(ins, make_op("mul"), it, ct);
auto cellt = m.insert_instruction(ins, make_op("add"), ft_cell, it_ct);
if(pph != prog.end())
if(pph != m.end())
{
auto ppho_cellt = prog.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
ot_before_actv =
prog.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
auto ppho_cellt = m.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
ot_before_actv = m.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
}
auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);
auto ot = m.insert_instruction(ins, actv_func1, ot_before_actv);
// Ht = ot (.) h(Ct)
auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
auto ht = prog.insert_instruction(ins, make_op("mul"), ot, h_cellt);
auto h_cellt = m.insert_instruction(ins, actv_func3, cellt);
auto ht = m.insert_instruction(ins, make_op("mul"), ot, h_cellt);
sic = cellt;
sih = ht;
last_hs_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
last_hs_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
last_cell_output =
prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);
m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);
if(i < seq_len - 1)
{
@@ -1173,12 +1167,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
{
auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
hidden_states = prog.insert_instruction(
hidden_states = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1);
auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
cell_outputs = prog.insert_instruction(
cell_outputs = m.insert_instruction(
ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
}
}
@@ -1266,10 +1260,10 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
}
}
bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const
bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens) const
{
bool is_var_lens = false;
if(seq_lens != prog.end())
if(seq_lens != m.end())
{
if(seq_lens->can_eval())
{
@@ -1296,12 +1290,12 @@ bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_l
}
std::size_t
rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const
rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const
{
bool is_var_lens = is_variable_seq_lens(prog, seq_lens);
bool is_var_lens = is_variable_seq_lens(m, seq_lens);
auto input_shape = input->get_shape();
auto length = input_shape.lens()[0];
if(!is_var_lens and seq_lens != prog.end())
if(!is_var_lens and seq_lens != m.end())
{
auto arg_len = seq_lens->eval();
std::vector<std::size_t> vec_lens;
@@ -1312,33 +1306,33 @@ rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_
return length;
}
instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
instruction_ref rewrite_rnn::replace_last_hs_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref last_hs_output,
op::rnn_direction dirct) const
{
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
instruction_ref result_ins{};
if(variable_seq_len)
{
result_ins = prog.insert_instruction(
std::next(ins),
make_op("rnn_var_sl_shift_output",
{{"output_name", "hidden_states"}, {"direction", dirct}}),
ins,
seq_lens);
prog.replace_instruction(ins, result_ins);
result_ins =
m.insert_instruction(std::next(ins),
make_op("rnn_var_sl_shift_output",
{{"output_name", "hidden_states"}, {"direction", dirct}}),
ins,
seq_lens);
m.replace_instruction(ins, result_ins);
auto hs_outputs = find_all(result_ins->outputs(),
[&](auto i) { return i->name() == "rnn_last_hs_output"; });
for(auto& hs_out : hs_outputs)
{
auto inputs = hs_out->inputs();
prog.replace_instruction(hs_out,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
inputs.front(),
seq_lens);
m.replace_instruction(hs_out,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
inputs.front(),
seq_lens);
}
}
else
@@ -1348,7 +1342,7 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
for(auto& hs_out : hs_outputs)
{
prog.replace_instruction(hs_out, last_hs_output);
m.replace_instruction(hs_out, last_hs_output);
}
result_ins = ins;
@@ -1357,14 +1351,14 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
return result_ins;
}
void rewrite_rnn::replace_last_cell_output(module& prog,
void rewrite_rnn::replace_last_cell_output(module& m,
instruction_ref ins,
instruction_ref seq_lens,
instruction_ref cell_outputs,
instruction_ref last_cell_output,
op::rnn_direction dirct) const
{
bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
auto ins_outputs =
find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; });
@@ -1372,7 +1366,7 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
{
if(!ins_outputs.empty())
{
cell_outputs = prog.insert_instruction(
cell_outputs = m.insert_instruction(
std::next(ins),
make_op("rnn_var_sl_shift_output",
{{"output_name", "cell_outputs"}, {"direction", dirct}}),
@@ -1382,10 +1376,10 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
for(auto co : ins_outputs)
{
prog.replace_instruction(co,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
cell_outputs,
seq_lens);
m.replace_instruction(co,
make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
cell_outputs,
seq_lens);
}
}
// replace the rnn_last_cell_output with the last_cell_output. The while
@@ -1394,18 +1388,18 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
{
for(auto co : ins_outputs)
{
prog.replace_instruction(co, last_cell_output);
m.replace_instruction(co, last_cell_output);
}
}
}
instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
instruction_ref rewrite_rnn::pad_hidden_states(module& m,
instruction_ref seq,
instruction_ref seq_lens,
instruction_ref hs) const
{
auto max_seq_len = seq->get_shape().lens()[0];
auto seq_len = get_seq_len(prog, seq, seq_lens);
auto seq_len = get_seq_len(m, seq, seq_lens);
// if all sequences are of the same length and
// shorter than max_seq_len, we need to append the hs outputs
@@ -1417,23 +1411,13 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
pad_lens[0] = static_cast<std::size_t>(max_seq_len - seq_len);
shape pad_s{s.type(), pad_lens};
std::vector<float> pad_data(pad_s.elements(), 0.0f);
auto pl = prog.add_literal(pad_s, pad_data.begin(), pad_data.end());
hs_padded =
prog.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
prog.replace_instruction(hs, hs_padded);
auto pl = m.add_literal(pad_s, pad_data.begin(), pad_data.end());
hs_padded = m.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
m.replace_instruction(hs, hs_padded);
}
return hs_padded;
}
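// A small worked example of the padding above (hypothetical sizes): with
// max_seq_len = 5 and seq_len = 3, pad_lens[0] = 2, so a literal of zeros with
// 2 time steps is concatenated to hs along axis 0 before replacing hs.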
namespace op {
std::ostream& operator<<(std::ostream& os, rnn_direction v)
{
std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
return os;
}
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
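// A minimal usage sketch (hypothetical driver code; only the apply() interface
// above is taken from this file):
//   migraphx::module m = ...;             // module containing rnn/gru/lstm instructions
//   migraphx::rewrite_rnn{}.apply(m);     // lowers them into primitive ops
//   migraphx::rewrite_pooling{}.apply(m);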
@@ -42,7 +42,7 @@ struct stream_info
std::unordered_map<instruction_ref, std::size_t> iweights;
ins_dep_map mod_implicit_deps;
void calc_implicit_deps(const module& p) { mod_implicit_deps = p.calc_implicit_deps(); }
void calc_implicit_deps(const module& m) { mod_implicit_deps = m.calc_implicit_deps(); }
void accumulate_weights(instruction_ref last, const schedule_model& model)
{
@@ -116,15 +116,15 @@ struct stream_info
}
};
std::size_t assign_streams(module& p, std::size_t n)
std::size_t assign_streams(module& m, std::size_t n)
{
assert(n > 0);
partition critical;
std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) {
assert(not is_end(ins, p.end()));
if(not p.has_instruction(ins))
assert(not is_end(ins, m.end()));
if(not m.has_instruction(ins))
return;
if(contains(partitions, ins))
return;
@@ -151,8 +151,8 @@ struct stream_info
}
}
// Sort instructions
p.move_instruction(ins, p.end());
})(std::prev(p.end()), critical);
m.move_instruction(ins, m.end());
})(std::prev(m.end()), critical);
// Set the critical partition to stream 0
set_stream(critical, 0);
@@ -197,13 +197,13 @@ struct stream_info
}
};
void sort(module& p, std::size_t)
void sort(module& m, std::size_t)
{
std::set<weight_ins, compare_weight_ins> children;
std::unordered_map<instruction_ref, std::size_t> visited;
auto last = std::prev(p.end());
auto last = std::prev(m.end());
auto mw = this->weights.at(last);
auto nw = mw / (p.size() + 1);
auto nw = mw / (m.size() + 1);
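// (illustrative note) mw is the total accumulated weight at the last instruction
// and nw normalizes it per instruction; e.g. with mw = 100 and a module of 24
// instructions, nw = 100 / 25 = 4, so x below grows as an instruction's
// remaining weight shrinks.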
auto add_child = [&](auto ins) {
auto x = 1 + (mw - this->weights.at(ins)) / (nw + 1);
auto w = x * this->iweights.at(ins);
@@ -222,10 +222,10 @@ struct stream_info
// Pop the first element
auto top = children.begin()->second;
children.erase(children.begin());
p.move_instruction(top, p.begin());
m.move_instruction(top, m.begin());
for(auto ins : top->inputs())
{
if(not p.has_instruction(ins))
if(not m.has_instruction(ins))
continue;
add_child(ins);
}
@@ -234,7 +234,7 @@ struct stream_info
{
for(auto ins : mod_implicit_deps.at(top))
{
assert(p.has_instruction(ins));
assert(m.has_instruction(ins));
add_child(ins);
}
}
@@ -242,12 +242,12 @@ struct stream_info
// move dangling parameters to the front so they are not removed
auto ins = std::next(last);
while(ins != p.end())
while(ins != m.end())
{
auto next = std::next(ins);
if(ins->name() == "@param")
{
p.move_instruction(ins, p.begin());
m.move_instruction(ins, m.begin());
}
ins = next;
}
@@ -364,18 +364,18 @@ struct stream_info
}
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>>
find_concurrent_instructions(module& p) const
find_concurrent_instructions(module& m) const
{
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
dominator_info di = compute_dominator(p);
result.reserve(p.size());
merge_from.reserve(p.size());
for(auto ins : reverse_iterator_for(p))
dominator_info di = compute_dominator(m);
result.reserve(m.size());
merge_from.reserve(m.size());
for(auto ins : reverse_iterator_for(m))
{
for(auto&& arg : ins->outputs())
{
if(not p.has_instruction(arg))
if(not m.has_instruction(arg))
continue;
if(is_merge_point(arg))
merge_from[ins].insert(arg);
......@@ -415,18 +415,18 @@ struct stream_info
}
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(module& p)
get_conflicts(module& m)
{
using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
conflict_table_type conflict_table;
auto concur_ins = this->find_concurrent_instructions(p);
auto concur_ins = this->find_concurrent_instructions(m);
// Compute an index for each instruction
std::unordered_map<instruction_ref, std::size_t> ins2index;
std::size_t index_total = 0;
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
ins2index[ins] = index_total++;
std::vector<conflict_table_type> thread_conflict_tables(
......@@ -507,21 +507,21 @@ struct stream_info
}
};
void schedule::apply(module& p) const
void schedule::apply(module& m) const
{
if(not enable)
return;
stream_info si;
si.calc_implicit_deps(p);
auto last = std::prev(p.end());
si.calc_implicit_deps(m);
auto last = std::prev(m.end());
si.accumulate_weights(last, model);
auto nstreams = si.assign_streams(p, model.concurrency());
si.sort(p, model.concurrency());
auto nstreams = si.assign_streams(m, model.concurrency());
si.sort(m, model.concurrency());
if(enabled(MIGRAPHX_TRACE_COMPILE{}) or enabled(MIGRAPHX_TRACE_SCHEDULE{}))
{
p.annotate(std::cout, [&](auto ins) {
m.annotate(std::cout, [&](auto ins) {
if(ins->name() == "@param" and not contains(si.weights, ins))
return;
......@@ -548,9 +548,9 @@ void schedule::apply(module& p) const
std::unordered_map<instruction_ref, std::size_t> ins2wait;
std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
ins2wait.reserve(p.size());
ins2waited.reserve(p.size());
for(auto ins : iterator_for(p))
ins2wait.reserve(m.size());
ins2waited.reserve(m.size());
for(auto ins : iterator_for(m))
{
// Only schedule instructions that have a stream
if(not si.has_stream(ins))
......@@ -559,7 +559,7 @@ void schedule::apply(module& p) const
// Schedule instruction on the stream
auto stream = si.get_stream(ins);
assert(stream < model.concurrency());
model.sched(p, ins, stream);
model.sched(m, ins, stream);
// Insert wait instructions
if(si.is_merge_point(ins, stream))
{
......@@ -572,14 +572,14 @@ void schedule::apply(module& p) const
if(not contains(ins2wait, i))
{
ins2wait[i] = wait_id;
model.record(p, i, wait_id);
model.record(m, i, wait_id);
wait_id++;
}
auto w = ins2wait.at(i);
// If we have already waited for the event on this stream then don't
// insert another wait event
if(not contains(waited_for[stream], w))
model.wait(p, ins, w);
model.wait(m, ins, w);
// Store the event as waited
waited_for[stream].insert(w);
// Store all wait events that have been waited on prior to the recorded instruction
......@@ -594,7 +594,7 @@ void schedule::apply(module& p) const
}
// Add memory conflicts
auto conflict_table = si.get_conflicts(p);
auto conflict_table = si.get_conflicts(m);
for(auto&& ip : conflict_table)
{
if(ip.second.empty())
......@@ -602,7 +602,7 @@ void schedule::apply(module& p) const
std::vector<instruction_ref> args;
args.push_back(ip.first);
args.insert(args.end(), ip.second.begin(), ip.second.end());
p.insert_instruction(std::next(ip.first), make_op("identity"), args);
m.insert_instruction(std::next(ip.first), make_op("identity"), args);
}
}
......
......@@ -86,6 +86,8 @@ struct shape_impl
return std::accumulate(
m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
}
std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
};
const std::vector<shape::type_t>& shape::types()
......@@ -135,6 +137,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}
shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
shape shape::from_permutation(type_t t,
const std::vector<std::size_t>& l,
const std::vector<int64_t>& perm)
......@@ -294,6 +298,13 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const
return this->with_lens(this->type(), l);
}
shape shape::with_type(type_t t) const
{
auto c = impl->copy();
c->m_type = t;
return {c};
}
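A usage note (illustrative sketch, not part of the commit): the new with_type copies the shape's pimpl and swaps only the element type, leaving lens and strides untouched. Assuming the usual shape(type, lens) constructor:

#include <migraphx/shape.hpp>
#include <cassert>

int main()
{
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    auto h = s.with_type(migraphx::shape::half_type);
    assert(h.type() == migraphx::shape::half_type);
    assert(h.lens() == s.lens()); // only the type changed; s itself is unchanged
    return 0;
}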
std::size_t shape::element_space() const { return impl->element_space(); }
std::string shape::type_string() const { return name(this->type()); }
......
......@@ -42,7 +42,7 @@ struct find_mul_conv
match::name("broadcast").bind("a")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto conv_ins = r.instructions["conv"];
......@@ -53,14 +53,14 @@ struct find_mul_conv
if(broadcast_op.axis != 1)
return;
auto new_a = p.insert_instruction(
auto new_a = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"out_lens", w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = p.insert_instruction(
auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, w_ins);
auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
p.replace_instruction(ins, new_conv);
m.replace_instruction(ins, new_conv);
}
};
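Why find_mul_conv is valid, sketched in 1-D with assumed data (not part of the commit): convolution is linear in the weights, so scaling the output by a per-channel constant equals convolving with pre-scaled weights.

#include <cassert>
#include <cstddef>
#include <vector>

std::vector<int> conv1d(const std::vector<int>& x, const std::vector<int>& w)
{
    std::vector<int> y(x.size() - w.size() + 1, 0);
    for(std::size_t i = 0; i < y.size(); i++)
        for(std::size_t j = 0; j < w.size(); j++)
            y[i] += x[i + j] * w[j];
    return y;
}

int main()
{
    std::vector<int> x = {1, 2, 3, 4}, w = {1, -1};
    int a = 3;
    auto lhs = conv1d(x, w);
    for(auto& v : lhs)
        v *= a;                                 // a * conv(x, w)
    auto rhs = conv1d(x, {a * w[0], a * w[1]}); // conv(x, a * w)
    assert(lhs == rhs);
    return 0;
}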
......@@ -80,7 +80,7 @@ struct find_mul_slice_conv
match::name("broadcast")(match::is_constant()).bind("a")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto slice_ins = r.instructions["slice"];
......@@ -116,38 +116,38 @@ struct find_mul_slice_conv
auto w_slice_op = slice_op;
w_slice_op.axes = {0};
auto slice_w_ins = p.insert_instruction(ins, w_slice_op, w_ins);
auto slice_w_ins = m.insert_instruction(ins, w_slice_op, w_ins);
auto new_a = p.insert_instruction(
auto new_a = m.insert_instruction(
ins,
make_op("broadcast", {{"axis", 0}, {"out_lens", slice_w_ins->get_shape().lens()}}),
a_ins->inputs().front());
auto new_mul = p.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
auto new_mul = m.insert_instruction(ins, make_op("mul"), new_a, slice_w_ins);
std::vector<instruction_ref> sliced_weights;
if(slice_op.starts.front() != 0)
sliced_weights.push_back(p.insert_instruction(
sliced_weights.push_back(m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", slice_op.starts}}),
w_ins));
sliced_weights.push_back(new_mul);
int64_t end_axis = w_ins->get_shape().lens().at(0);
if(slice_op.ends.front() != end_axis)
sliced_weights.push_back(p.insert_instruction(
sliced_weights.push_back(m.insert_instruction(
ins,
make_op("slice", {{"axes", {0}}, {"starts", slice_op.ends}, {"ends", {end_axis}}}),
w_ins));
auto new_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 0}}), sliced_weights);
m.insert_instruction(ins, make_op("concat", {{"axis", 0}}), sliced_weights);
auto new_conv = p.insert_instruction(
auto new_conv = m.insert_instruction(
ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights);
assert(conv_ins->get_shape() == new_conv->get_shape());
auto slice1 = p.insert_instruction(ins, slice_op, new_conv);
auto slice1 = m.insert_instruction(ins, slice_op, new_conv);
assert(ins->get_shape().lens() == slice1->get_shape().lens());
p.replace_instruction(ins, slice1);
m.replace_instruction(ins, slice1);
// TODO: Check each slice doesn't overlap and that it occurs after slice_ins
auto outputs = conv_ins->outputs();
for(auto output : outputs)
......@@ -171,7 +171,7 @@ struct find_mul_add
match::is_constant().bind("a")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
......@@ -179,9 +179,9 @@ struct find_mul_add
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);
auto ax_ins = p.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = p.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
p.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
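The rewrite above relies on distributivity: a * (x + b) becomes a*x + a*b, so the constant product a*b can fold away. A tiny check with assumed values (illustrative only):

#include <cassert>

int main()
{
    int a = 2, x = 3, b = 4;
    assert(a * (x + b) == a * x + a * b); // 14 == 14
    return 0;
}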
......@@ -193,15 +193,15 @@ struct find_add_lit_broadcast
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins);
p.replace_instruction(ins, make_op("add"), x_ins, sumab);
auto sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), x_ins, sumab);
}
};
......@@ -213,7 +213,7 @@ struct find_double_add_lit_broadcast
match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -228,17 +228,17 @@ struct find_double_add_lit_broadcast
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return;
auto op = a_ins->get_operator();
auto presum = p.insert_instruction(
auto presum = m.insert_instruction(
ins, make_op("add"), a_ins->inputs().at(0), b_ins->inputs().at(0));
sumab = p.insert_instruction(ins, op, presum);
sumab = m.insert_instruction(ins, op, presum);
}
else
{
sumab = p.insert_instruction(ins, make_op("add"), a_ins, b_ins);
sumab = m.insert_instruction(ins, make_op("add"), a_ins, b_ins);
}
auto sumxy = p.insert_instruction(ins, make_op("add"), x_ins, y_ins);
p.replace_instruction(ins, make_op("add"), sumxy, sumab);
auto sumxy = m.insert_instruction(ins, make_op("add"), x_ins, y_ins);
m.replace_instruction(ins, make_op("add"), sumxy, sumab);
}
};
......@@ -251,7 +251,7 @@ struct find_inner_broadcast
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
......@@ -263,9 +263,9 @@ struct find_inner_broadcast
if(xbroadcast.axis != ybroadcast.axis)
return;
auto op = p.insert_instruction(
auto op = m.insert_instruction(
ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
p.replace_instruction(ins, xbroadcast, op);
m.replace_instruction(ins, xbroadcast, op);
}
};
......@@ -296,7 +296,7 @@ struct find_concat_op
return op.name() == "broadcast" or op.attributes().contains("pointwise");
}
void apply(module& p, const match::matcher_result& r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto axis = any_cast<op::concat>(ins->get_operator()).axis;
......@@ -330,12 +330,11 @@ struct find_concat_op
return j->inputs().at(i);
});
auto concat =
p.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs);
m.insert_instruction(ins, make_op("concat", {{"axis", iaxis}}), inputs);
concats.push_back(concat);
}
auto y = p.insert_instruction(ins, op, concats);
auto y = m.insert_instruction(ins, op, concats);
return {y};
};
std::vector<instruction_ref> args;
......@@ -350,9 +349,9 @@ struct find_concat_op
};
group_unique(ins->inputs().begin(), ins->inputs().end(), update_args, pred);
if(args.size() == 1)
p.replace_instruction(ins, args.front());
m.replace_instruction(ins, args.front());
else
p.replace_instruction(ins, make_op("concat", {{"axis", axis}}), args);
m.replace_instruction(ins, make_op("concat", {{"axis", axis}}), args);
}
};
......@@ -479,14 +478,14 @@ struct find_splits
return true;
}
void apply(module& p, const match::matcher_result& r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto splits = get_splits(ins);
if(splits.empty())
return;
for(const auto& group : get_split_groups(p, splits))
for(const auto& group : get_split_groups(m, splits))
{
auto start = group.front();
auto split_front = splits.front();
......@@ -501,10 +500,10 @@ struct find_splits
std::next(group.begin()), group.end(), [&](auto i) { return i == start; }));
auto split_idx = 0;
instruction_ref c = p.end();
instruction_ref c = m.end();
if(start->inputs().size() == 1)
{
c = p.insert_instruction(std::next(ins), op, ins);
c = m.insert_instruction(std::next(ins), op, ins);
}
else if(start->inputs().size() == 2)
{
......@@ -531,7 +530,7 @@ struct find_splits
return;
for(auto data : data_args)
p.move_instructions(data, ins);
m.move_instructions(data, ins);
auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
assert(not slice_op.axes.empty());
......@@ -539,16 +538,16 @@ struct find_splits
return;
auto concat_axis = slice_op.axes.front();
// TODO: Check if axes match
auto concat = p.insert_instruction(
auto concat = m.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), data_args);
std::vector<instruction_ref> args;
args.resize(2);
args[split_idx] = ins;
args[data_idx] = concat;
c = p.insert_instruction(std::next(ins), op, args);
c = m.insert_instruction(std::next(ins), op, args);
}
if(c != p.end())
if(c != m.end())
{
for(auto i : group)
{
......@@ -561,11 +560,11 @@ struct find_splits
if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name()))
continue;
auto x =
p.insert_instruction(output, make_op("contiguous"), output->inputs());
p.replace_instruction(output, output->get_operator(), x);
m.insert_instruction(output, make_op("contiguous"), output->inputs());
m.replace_instruction(output, output->get_operator(), x);
}
p.replace_instruction(i, split->get_operator(), c);
m.replace_instruction(i, split->get_operator(), c);
}
}
}
......@@ -580,7 +579,7 @@ struct find_split_concat
match::name("slice")(match::all_of[match::outputs()](match::name("concat")))));
}
void apply(module& p, const match::matcher_result& r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -620,9 +619,9 @@ struct find_split_concat
args.erase(std::next(it), it + splits.size());
if(args.size() == 1)
p.replace_instruction(concat, args.front());
m.replace_instruction(concat, args.front());
else
p.replace_instruction(concat, concat->get_operator(), args);
m.replace_instruction(concat, concat->get_operator(), args);
}
};
......@@ -665,7 +664,7 @@ struct find_add_convs
return x.stride[0] / y.stride[0];
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_conv = r.instructions["a"];
......@@ -694,7 +693,7 @@ struct find_add_convs
if(n == 0)
return;
new_op = a_op;
b_input = p.insert_instruction(
b_input = m.insert_instruction(
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), b_input);
}
else if(b_op.stride < a_op.stride)
......@@ -703,7 +702,7 @@ struct find_add_convs
if(n == 0)
return;
new_op = b_op;
a_input = p.insert_instruction(
a_input = m.insert_instruction(
ins, make_op("step", {{"axes", {2, 3}}, {"steps", {n, n}}}), a_input);
}
else
......@@ -714,10 +713,10 @@ struct find_add_convs
}
auto concat_input =
p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_input, b_input);
m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_input, b_input);
auto concat_weights =
p.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_weights, b_weights);
p.replace_instruction(ins, new_op, concat_input, concat_weights);
m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), a_weights, b_weights);
m.replace_instruction(ins, new_op, concat_input, concat_weights);
}
};
......@@ -738,7 +737,7 @@ struct find_conv_dot_horiz_fusion
{
auto matcher() const { return horiz_conv_dot(); }
void apply(module& p, const match::matcher_result& r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -786,16 +785,16 @@ struct find_conv_dot_horiz_fusion
}
for(auto arg : args)
p.move_instructions(arg, input);
m.move_instructions(arg, input);
// TODO: Check if axes match
auto concat =
p.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = p.insert_instruction(std::next(input), op, input, concat);
m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
auto fused = m.insert_instruction(std::next(input), op, input, concat);
int64_t offset = 0;
for(auto arg : range(start, last))
{
int64_t len = arg->get_shape().lens()[axis];
p.replace_instruction(
m.replace_instruction(
arg,
make_op("slice",
{{"axes", {axis}}, {"starts", {offset}}, {"ends", {offset + len}}}),
......@@ -816,16 +815,16 @@ struct find_div_const
return match::name("div")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
auto recip = p.insert_instruction(std::next(c_ins), make_op("recip"), c_ins);
auto recip = m.insert_instruction(std::next(c_ins), make_op("recip"), c_ins);
auto args = ins->inputs();
p.replace_instruction(ins, make_op("mul"), args.front(), recip);
m.replace_instruction(ins, make_op("mul"), args.front(), recip);
}
};
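The intent of find_div_const, sketched with assumed values (not part of the commit): a division by a constant becomes a multiplication by its reciprocal, computed once next to the constant so it can be constant-folded.

#include <cassert>

int main()
{
    double x = 9.0, c = 4.0;
    double recip = 1.0 / c;     // folded once for a literal c
    assert(x / c == x * recip); // exact here since 4.0 is a power of two
    return 0;
}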
......@@ -836,16 +835,16 @@ struct find_sub_const
return match::name("sub")(match::arg(1)(match::is_constant().bind("c")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto c_ins = r.instructions["c"];
auto neg = p.insert_instruction(std::next(c_ins), make_op("neg"), c_ins);
auto neg = m.insert_instruction(std::next(c_ins), make_op("neg"), c_ins);
auto args = ins->inputs();
p.replace_instruction(ins, make_op("add"), args.front(), neg);
m.replace_instruction(ins, make_op("add"), args.front(), neg);
}
};
......@@ -857,12 +856,12 @@ struct find_rsqrt
match::name("sqrt")(match::used_once(), match::args(match::any().bind("x")))));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
p.replace_instruction(ins, make_op("rsqrt"), x_ins);
m.replace_instruction(ins, make_op("rsqrt"), x_ins);
}
};
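find_rsqrt collapses recip(sqrt(x)) into the single rsqrt op; numerically, with an assumed power-of-two input (illustrative only):

#include <cassert>
#include <cmath>

int main()
{
    double x = 16.0;
    double fused = 1.0 / std::sqrt(x); // the pattern the matcher recognizes
    assert(fused == 0.25);             // exact for this input
    return 0;
}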
......@@ -882,7 +881,7 @@ struct find_split_reshape
.bind("reshape");
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"];
......@@ -937,14 +936,14 @@ struct find_split_reshape
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0});
// insert the reshape instruction
auto rsp_ins = p.insert_instruction(
auto rsp_ins = m.insert_instruction(
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace the original reshapes with slices
int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{
p.replace_instruction(
m.replace_instruction(
vec_rsp[i],
make_op(
"slice",
......@@ -963,7 +962,7 @@ struct find_split_transpose
.bind("trans");
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto slc = r.instructions["slice"];
auto trans = r.instructions["trans"];
......@@ -989,14 +988,14 @@ struct find_split_transpose
}
// insert a transpose instruction
auto tr = p.insert_instruction(
auto tr = m.insert_instruction(
std::next(input), make_op("transpose", {{"permutation", perm}}), input);
// compute the axis in the slice
auto axis = any_cast<op::slice>(slc->get_operator()).axes.front();
auto it = std::find(perm.begin(), perm.end(), axis);
assert(it != perm.end());
auto axis_new = static_cast<int64_t>(std::distance(perm.begin(), it));
int64_t axis_new = std::distance(perm.begin(), it);
for(auto in : split_outputs)
{
......@@ -1004,7 +1003,7 @@ struct find_split_transpose
auto starts = oper.starts;
auto ends = oper.ends;
auto tr_orig = in->outputs().front();
p.replace_instruction(
m.replace_instruction(
tr_orig,
make_op("slice", {{"axes", {axis_new}}, {"starts", starts}, {"ends", ends}}),
tr);
......@@ -1012,12 +1011,12 @@ struct find_split_transpose
}
};
void simplify_algebra::apply(module& p) const
void simplify_algebra::apply(module& m) const
{
// Run simplifications multiple times
for(int i = 0; i < 8; i++)
{
match::find_matches(p,
match::find_matches(m,
find_inner_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
......@@ -1034,7 +1033,7 @@ void simplify_algebra::apply(module& p) const
find_splits{},
find_split_reshape{},
find_split_transpose{});
dead_code_elimination{}.apply(p);
dead_code_elimination{}.apply(m);
}
}
......
......@@ -53,7 +53,7 @@ struct match_find_quantizable_ops
match::arg(1)(dequantizelinear_op("x2", "scale2")));
}
void apply(module& m, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto qop = r.result;
auto q1 = r.instructions["x1"];
......
......@@ -70,19 +70,19 @@ struct find_reshaper
match::any_of[match::outputs()](match::name(reshaper_names())));
}
void apply(module& p, const match::matcher_result& mr) const
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front()));
assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()};
for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
......@@ -96,7 +96,7 @@ struct find_reshaper
}
if(r.first != r.second)
{
p.replace_instruction(r.first, r.second);
m.replace_instruction(r.first, r.second);
}
}
};
......@@ -117,10 +117,10 @@ struct find_nop_reshapes
return match::name(reshapes)(match::same_shape(match::arg(0)));
}
void apply(module& p, const match::matcher_result& mr) const
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
p.replace_instruction(ins, ins->inputs().front());
m.replace_instruction(ins, ins->inputs().front());
}
};
......@@ -132,7 +132,7 @@ struct find_transpose
match::skip_output(match::name("contiguous"))(match::name("transpose"))));
}
void apply(module& p, const match::matcher_result& mr) const
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto x = ins;
......@@ -149,11 +149,11 @@ struct find_transpose
return;
if(is_no_transpose(dims))
{
p.replace_instruction(ins, t->inputs().front());
m.replace_instruction(ins, t->inputs().front());
}
else
{
p.replace_instruction(
m.replace_instruction(
ins, make_op("transpose", {{"permutation", dims}}), t->inputs().front());
}
}
......@@ -223,7 +223,7 @@ struct find_nested_slice
return result;
}
void apply(module& p, const match::matcher_result& mr) const
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto slice = ins->inputs().front();
......@@ -241,7 +241,7 @@ struct find_nested_slice
op.starts.push_back(pp.second.first);
op.ends.push_back(pp.second.second);
}
p.replace_instruction(ins, op, input);
m.replace_instruction(ins, op, input);
}
};
......@@ -252,7 +252,7 @@ struct find_concat_transpose
return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
}
void apply(module& p, const match::matcher_result& mr) const
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto trans_inputs = ins->inputs();
......@@ -279,14 +279,14 @@ struct find_concat_transpose
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
return p.insert_instruction(
return m.insert_instruction(
ins, make_op("transpose", {{"permutation", permutation}}), i);
});
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(
auto concat = m.insert_instruction(ins, op, inputs);
auto t = m.insert_instruction(
ins, make_op("transpose", {{"permutation", ipermutation}}), concat);
assert(ins->get_shape().lens() == t->get_shape().lens());
p.replace_instruction(ins, t);
m.replace_instruction(ins, t);
}
};
......@@ -303,7 +303,7 @@ struct find_nested_concat
return op.axis;
}
void apply(module& p, const match::matcher_result& mr) const
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto axis = get_axis(ins);
......@@ -316,9 +316,8 @@ struct find_nested_concat
else
args.push_back(i);
}
})(ins->inputs());
p.replace_instruction(ins, ins->get_operator(), args);
m.replace_instruction(ins, ins->get_operator(), args);
}
};
......@@ -330,7 +329,7 @@ struct find_resize
match::args(match::name("reshape").bind("data"), match::is_constant().bind("ind")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_rsp = r.instructions["data"];
......@@ -418,13 +417,13 @@ struct find_resize
}
auto in_rsp = ins_rsp->inputs().front();
auto rsp_data = p.insert_instruction(
auto rsp_data = m.insert_instruction(
ins_rsp, migraphx::make_op("reshape", {{"dims", in_dims}}), in_rsp);
auto mb_rsp = p.insert_instruction(
auto mb_rsp = m.insert_instruction(
ins_rsp, migraphx::make_op("multibroadcast", {{"out_lens", out_dims}}), rsp_data);
auto std_mb = p.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
auto std_mb = m.insert_instruction(ins, migraphx::make_op("contiguous"), mb_rsp);
std::vector<int64_t> rsp_dims(out_lens.begin(), out_lens.end());
p.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
m.replace_instruction(ins, migraphx::make_op("reshape", {{"dims", rsp_dims}}), std_mb);
}
};
......@@ -437,7 +436,7 @@ struct find_where_op
match::is_constant().bind("ind")));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto concat = r.instructions["data"];
......@@ -476,11 +475,11 @@ struct find_where_op
if(val)
{
p.replace_instruction(ins, inputs.at(0));
m.replace_instruction(ins, inputs.at(0));
}
else
{
p.replace_instruction(ins, inputs.at(1));
m.replace_instruction(ins, inputs.at(1));
}
}
};
......@@ -497,7 +496,7 @@ struct find_reshape_cont
match::any()));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_cont = r.instructions["cont"];
......@@ -531,11 +530,11 @@ struct find_reshape_cont
else
{
inputs.push_back(
p.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), in));
m.insert_instruction(ins, make_op("reshape", {{"dims", dims}}), in));
}
}
auto out = p.insert_instruction(ins, ins->get_operator(), inputs);
p.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), out);
auto out = m.insert_instruction(ins, ins->get_operator(), inputs);
m.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), out);
}
};
......@@ -565,25 +564,25 @@ struct find_transpose_contiguous_reshaper_unary
match::args(match_transpose_contiguous_reshaper()));
}
void apply(module& p, match::matcher_result r) const
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"];
auto cont_ins = r.instructions["cont_ins"];
auto unary_op_name = ins->get_operator().name();
auto unary_ins = p.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = p.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
auto unary_ins = m.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = m.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
// the older contiguous and reshape are removed by dead code elimination
p.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins);
m.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins);
}
};
void simplify_reshapes::apply(module& p) const
void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < 2; i++)
{
match::find_matches(p,
match::find_matches(m,
find_where_op{},
find_resize{},
find_reshape_cont{},
......@@ -595,7 +594,7 @@ void simplify_reshapes::apply(module& p) const
find_nested_slice{},
find_nested_concat{},
find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(p);
dead_code_elimination{}.apply(m);
}
}
......
......@@ -20,7 +20,6 @@ struct cpu_copy : reduce_dims_base, auto_register_op<cpu_copy>
return inputs.at(1);
}
argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
argument result = get_arg(args, args.size() - 1);
......
......@@ -26,7 +26,6 @@ struct cpu_gather : auto_register_op<cpu_gather>
}
argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
std::size_t nelements = output_shape.elements();
......
......@@ -7,7 +7,16 @@
#ifdef MIGRAPHX_DISABLE_OMP
#include <migraphx/par_for.hpp>
#else
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier"
#endif
#include <omp.h>
#ifdef __clang__
#pragma clang diagnostic pop
#endif
#endif
namespace migraphx {
......
......@@ -213,7 +213,6 @@ template <std::size_t N, class... Xs>
bool is_vectorizable(const Xs&... xs)
{
return all_of({xs...}, [](const auto& s) {
if(s.standard() and (s.lens().back() % N) == 0)
return true;
if(s.broadcasted())
......@@ -320,11 +319,10 @@ struct cpu_unary : reduce_dims_base, auto_register_op<cpu_unary<Op>>
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2);
auto s = inputs.at(0);
const auto& s = inputs.at(0);
return {s.type(), s.lens()};
}
argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
argument result = get_arg(args, args.size() - 1);
......@@ -358,12 +356,11 @@ struct cpu_binary : reduce_dims_base, auto_register_op<cpu_binary<Op>>
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(3);
auto s = inputs.at(0);
const auto& s = inputs.at(0);
return {s.type(), s.lens()};
}
argument
// cppcheck-suppress constParameter
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
argument result = get_arg(args, args.size() - 1);
......
......@@ -223,7 +223,7 @@ struct cpu_unary2 : auto_register_op<cpu_unary2<Op>>
shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
const auto& s = inputs.at(0);
return {s.type(), s.lens()};
}
......@@ -352,7 +352,7 @@ struct cpu_apply
std::transform(bind_inputs.begin(),
bind_inputs.end(),
std::back_inserter(inputs),
[&](const auto& s) { return r.instructions.at(s); });
[&](const auto& s) { return r.instructions[s]; });
inputs.push_back(this->insert_allocation(ins, ins->get_shape()));
modl->replace_instruction(ins, op, inputs);
});
......@@ -460,11 +460,6 @@ struct cpu_apply
if(has_op("dnnl::pooling") and ins->get_shape().type() == shape::type_t::float_type and
not v["ceil_mode"].to<bool>())
return replace(ins, make_op("dnnl::pooling", op.to_value()));
std::string mode = v["mode"].to<std::string>();
if(mode == "max")
return replace(ins, make_op("cpu::pooling_max", v));
else if(mode == "average")
return replace(ins, make_op("cpu::pooling_average", v));
return ins;
}
......
......@@ -11,125 +11,14 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
struct max_pool
{
static std::string name() { return "max"; }
template <class T>
static T start()
{
return std::numeric_limits<T>::lowest();
}
static double apply(double x, double y)
{
double m = std::max(x, y);
return (m);
}
static double final(double x, std::size_t) { return (x); }
};
struct avg_pool
{
static std::string name() { return "average"; }
template <class T>
static double start()
{
return 0.0;
}
static double apply(double x, double y) { return x + y; }
static double final(double x, std::size_t y) { return (y == 0) ? 0.0 : (x / y); }
};
template <class Op>
struct cpu_pooling : auto_register_op<cpu_pooling<Op>>
{
cpu_pooling() = default;
cpu_pooling(op::pooling pop) : op(std::move(pop)) {}
op::pooling op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "cpu::pooling_" + Op::name(); }
shape compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.normalize_compute_shape(inputs);
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
visit_all(args.back(), args[0])([&](auto output, auto input) {
using type = typename decltype(output)::value_type;
auto in_s = input.get_shape();
auto in_lens = in_s.lens();
std::vector<std::size_t> vec_len(in_lens.begin() + 2, in_lens.end());
par_for(output_shape.elements(), [&](auto i) {
auto idx_o = output_shape.multi(i);
auto n_dim = idx_o.size();
std::vector<std::size_t> win_start;
std::vector<std::size_t> win_size;
for(std::size_t dim = 2; dim < n_dim; ++dim)
{
auto d_2 = dim - 2;
int start = static_cast<int>(idx_o[dim] * op.stride[d_2]) -
static_cast<int>(op.padding[d_2]);
int end = std::min(start + op.lengths[d_2], in_lens[dim]);
start = std::max(start, 0);
win_start.push_back(start);
win_size.push_back(end - start);
}
shape win_shape{output_shape.type(), win_size};
auto pool_size = win_shape.elements();
double acc = Op::template start<type>();
shape_for_each(win_shape, [&](auto idx_w) {
auto idx = idx_o;
std::transform(idx_w.begin(),
idx_w.end(),
win_start.begin(),
idx.begin() + 2,
[](auto ii, auto jj) { return ii + jj; });
if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
idx < in_lens)
{
acc = Op::apply(acc, input[in_s.index(idx)]);
}
});
output[i] = type(Op::final(acc, pool_size));
});
});
return args.back();
}
};
template struct cpu_pooling<avg_pool>;
template struct cpu_pooling<max_pool>;
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling>
{
std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; }
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
auto algo = op.mode == "max" ? dnnl::algorithm::pooling_max : dnnl::algorithm::pooling_avg;
auto algo = op.mode == op::pooling_mode::max ? dnnl::algorithm::pooling_max
: dnnl::algorithm::pooling_avg;
auto kdims = op.kdims();
std::vector<size_t> padding_l(op.padding.begin(), op.padding.begin() + kdims);
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
......@@ -145,5 +34,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -11,7 +11,7 @@ if(NOT TARGET MIOpen)
endif()
include(Embed)
file(GLOB KERNEL_FILES
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES})
......@@ -93,7 +93,7 @@ add_library(migraphx_device
)
add_library(compile_for_gpu INTERFACE)
target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument -Wno-option-ignored)
check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
if(HAS_HIP_LAMBDA_HOST_DEVICE)
message(STATUS "Enable -fhip-lambda-host-device")
......@@ -114,11 +114,13 @@ foreach(KERNEL_FILE ${KERNEL_FILES})
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include <migraphx/kernels/${KERNEL_BASE_FILE}.hpp>\n")
target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp)
endforeach()
target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256)
target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>)
target_link_libraries(kernel_file_check compile_for_gpu)
rocm_clang_tidy_check(kernel_file_check)
file(GLOB JIT_GPU_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
add_library(migraphx_gpu
abs.cpp
analyze_streams.cpp
......@@ -129,10 +131,10 @@ add_library(migraphx_gpu
clip.cpp
code_object_op.cpp
compile_ops.cpp
compile_gen.cpp
compile_hip.cpp
compile_hip_code_object.cpp
compile_pointwise.cpp
compile_roialign.cpp
compiler.cpp
concat.cpp
convert.cpp
convolution.cpp
......@@ -157,6 +159,7 @@ add_library(migraphx_gpu
nonzero.cpp
pack_args.cpp
pack_int8_args.cpp
prefuse_ops.cpp
pad.cpp
pooling.cpp
quant_convolution.cpp
......@@ -170,6 +173,7 @@ add_library(migraphx_gpu
target.cpp
topk.cpp
write_literals.cpp
${JIT_GPU_SRCS}
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
......@@ -330,6 +334,12 @@ target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}"
"-DMIGRAPHX_USE_HIPRTC=0"
)
if(DEFINED CMAKE_CXX_COMPILER_LAUNCHER)
execute_process(COMMAND which ${CMAKE_CXX_COMPILER_LAUNCHER} OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER)
string(STRIP "${MIGRAPHX_HIP_COMPILER_LAUNCHER}" MIGRAPHX_HIP_COMPILER_LAUNCHER)
target_compile_definitions(migraphx_gpu PRIVATE "-DMIGRAPHX_HIP_COMPILER_LAUNCHER=${MIGRAPHX_HIP_COMPILER_LAUNCHER}")
endif()
endif()
# Check miopen find mode api
......
......@@ -28,30 +28,30 @@ struct hip_stream_model
bool is_wait(migraphx::instruction_ref ins) const { return ins->name() == "gpu::wait_event"; }
};
stream_model make_stream_model(const module& p)
stream_model make_stream_model(const module& m)
{
hip_stream_model m;
hip_stream_model hsm;
std::size_t stream = 0;
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(m))
{
if(ins->name() == "gpu::set_stream")
{
auto v = ins->get_operator().to_value();
stream = v["stream"].to<std::size_t>();
m.max_stream = std::max(stream, m.max_stream);
auto v = ins->get_operator().to_value();
stream = v["stream"].to<std::size_t>();
hsm.max_stream = std::max(stream, hsm.max_stream);
}
if(ins->get_operator().is_context_free())
continue;
if(contains({"hip::hip_allocate_memory", "hip::hip_copy_literal", "@param"}, ins->name()))
continue;
m.ins2stream[ins] = stream;
hsm.ins2stream[ins] = stream;
}
return m;
return hsm;
}
std::vector<stream_race> analyze_streams(const module& p)
std::vector<stream_race> analyze_streams(const module& m)
{
return migraphx::analyze_streams(p, make_stream_model(p));
return migraphx::analyze_streams(m, make_stream_model(m));
}
} // namespace gpu
......
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace gen {
static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
{
// If all inputs are half then only use half2
if(std::all_of(inputs.begin(), inputs.end(), [](const auto& s) {
return s.type() == shape::half_type;
}))
return {2};
return {4, 2};
}
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
{
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(max_vec_size),
[&](const auto& input) -> std::size_t {
auto stride = input.strides()[axis];
auto len = input.lens()[axis];
if(stride != 0 and stride != 1)
return 1;
if(len == 1 and input.elements() > sizes.front())
return sizes.front();
auto it = std::find_if(
sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
if(it != sizes.end())
return *it;
return 1;
});
return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis};
}
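A standalone sketch of the per-input rule above (shapes assumed; the len == 1 broadcast special case is omitted): the largest candidate size that divides the length along the axis is chosen, and a stride other than 0 or 1 disables vectorization for that input.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

std::size_t pick_vec_size(std::size_t len, std::size_t stride, const std::vector<std::size_t>& sizes)
{
    if(stride != 0 and stride != 1)
        return 1; // non-contiguous along the axis: no vectorization
    auto it = std::find_if(sizes.begin(), sizes.end(), [&](auto i) { return (len % i) == 0; });
    return it == sizes.end() ? 1 : *it;
}

int main()
{
    std::vector<std::size_t> sizes = {4, 2}; // half-only inputs would use {2}
    assert(pick_vec_size(8, 1, sizes) == 4); // 8 % 4 == 0
    assert(pick_vec_size(6, 1, sizes) == 2); // 6 % 4 != 0 but 6 % 2 == 0
    assert(pick_vec_size(8, 3, sizes) == 1); // strided input
    return 0;
}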
std::string vectorize::str() const
{
return "vectorize<" + to_string(size) + ", " + to_string(axis) + ">()";
}
preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
{
const std::size_t max_lds_bytes = 4096;
std::vector<bool> result;
std::transform(inputs.begin(),
inputs.end(),
std::back_inserter(result),
[&](const shape& input) { return input.strides()[axis] == 0; });
auto bytes = std::inner_product(inputs.begin(),
inputs.end(),
result.begin(),
std::size_t{0},
std::plus<>{},
[](const shape& s, bool b) -> std::size_t {
if(b)
return s.bytes();
return 0;
});
if(bytes < max_lds_bytes)
return {result};
// TODO: Try to partially preload items
std::fill(result.begin(), result.end(), false);
return {result};
}
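The budget logic above in isolation (the 4096-byte LDS cap comes from the code; the input sizes are assumed): broadcast inputs are marked for preload only while their combined byte count stays under the cap, otherwise none are preloaded.

#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

int main()
{
    const std::size_t max_lds_bytes = 4096;
    // bytes of each broadcast input; non-broadcast inputs contribute 0
    std::vector<std::size_t> preload_bytes = {1024, 2048};
    std::size_t bytes =
        std::accumulate(preload_bytes.begin(), preload_bytes.end(), std::size_t{0});
    assert(bytes < max_lds_bytes); // 3072 < 4096: both inputs get preloaded
    return 0;
}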
std::string preload::str() const
{
std::vector<std::string> bool_strs;
std::transform(args.begin(), std::prev(args.end()), std::back_inserter(bool_strs), [](bool b) {
if(b)
return "true";
return "false";
});
return "auto_preload<false, " + join_strings(bool_strs, ", ") + ">(idx)";
}
bool preload::is_preloading() const
{
return std::accumulate(args.begin(), args.end(), false, std::logical_or<>{});
}
std::size_t find_fast_axis(const std::vector<shape>& inputs)
{
auto permutation = find_permutation(inputs);
auto it = std::max_element(permutation.begin(), permutation.end());
return it - permutation.begin();
}
std::string make_transformer_args(std::vector<std::string> transformers)
{
return join_strings(std::move(transformers), ", ");
}
} // namespace gen
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -21,6 +21,8 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_ASM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DUMP_SRC);
#if MIGRAPHX_USE_HIPRTC
......@@ -178,6 +180,19 @@ bool is_hip_clang_compiler()
return result;
}
bool has_compiler_launcher()
{
static const auto result = fs::exists(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER));
return result;
}
src_compiler assemble(src_compiler compiler)
{
compiler.out_ext = ".S";
compiler.flags = replace_string(compiler.flags, " -c", " -S");
return compiler;
}
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{
......@@ -210,6 +225,10 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
src_compiler compiler;
compiler.flags = params;
compiler.compiler = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER);
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if(has_compiler_launcher())
compiler.launcher = MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_LAUNCHER);
#endif
if(is_hcc_compiler())
compiler.process = [&](const fs::path& obj_path) -> fs::path {
......@@ -228,6 +247,22 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
MIGRAPHX_THROW("Missing hsaco");
};
if(enabled(MIGRAPHX_GPU_DUMP_SRC{}))
{
for(const auto& src : srcs)
{
if(src.path.extension() != ".cpp")
continue;
std::cout << std::string(src.content.first, src.len()) << std::endl;
}
}
if(enabled(MIGRAPHX_GPU_DUMP_ASM{}))
{
std::cout << assemble(compiler).compile(srcs).data() << std::endl;
}
return {compiler.compile(srcs)};
}
......@@ -238,13 +273,6 @@ std::string enum_params(std::size_t count, std::string param)
return join_strings(items, ",");
}
std::size_t compute_global(std::size_t n, std::size_t local)
{
std::size_t groups = (n + local - 1) / local;
std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
return nglobal;
}
#endif // MIGRAPHX_USE_HIPRTC
} // namespace gpu
......
......@@ -93,8 +93,47 @@ const std::vector<std::string>& compiler_warnings()
return warnings;
}
void hip_compile_options::set_launch_params(
const value& v,
const std::function<std::size_t(std::size_t local)>& compute_global,
std::size_t default_local)
{
local = v.get("local", default_local);
if(v.contains("global"))
global = v.at("global").to<std::size_t>();
else
global = compute_global(local);
}
std::function<std::size_t(std::size_t local)>
compute_global_for(context& ctx, std::size_t n, std::size_t over)
{
assert(over > 0);
std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu();
return [n, over, max_global](std::size_t local) {
std::size_t groups = (n + local - 1) / local;
std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local;
return nglobal;
};
}
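A worked trace of the returned lambda with assumed device numbers (120 CUs, 2048 workitems per CU): the launch is capped at device occupancy rather than problem size, and each workitem then strides over several elements.

#include <algorithm>
#include <cassert>
#include <cstddef>

int main()
{
    std::size_t n = 1000000, local = 256, over = 1;
    std::size_t max_global = 120 * 2048;              // cu_count * max_workitems_per_cu (assumed)
    std::size_t groups     = (n + local - 1) / local; // 3907 blocks would cover n
    std::size_t max_blocks = max_global / local;      // 960 blocks saturate the device
    std::size_t nglobal    = std::min(max_blocks * over, groups) * local;
    assert(nglobal == 245760); // capped well below n; the kernel loops per workitem
    return 0;
}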
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{
size_t block_size = 128;
while(block_size <= max_block_size and block_size <= n)
block_size *= 2;
return block_size / 2;
}
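A quick trace of the doubling rule (values assumed for illustration): the size doubles from 128 while it stays within both bounds, then steps back one doubling, flooring at 64.

#include <cassert>
#include <cstddef>

std::size_t compute_block_size_sketch(std::size_t n, std::size_t max_block_size)
{
    std::size_t block_size = 128;
    while(block_size <= max_block_size and block_size <= n)
        block_size *= 2;
    return block_size / 2;
}

int main()
{
    assert(compute_block_size_sketch(1000, 1024) == 512); // 128->256->512->1024, step back
    assert(compute_block_size_sketch(8, 1024) == 64);     // loop never runs; floor is 64
    return 0;
}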
operation compile_hip_code_object(const std::string& content, hip_compile_options options)
{
assert(options.global > 0);
assert(options.local > 0);
assert(not options.inputs.empty());
assert(options.inputs.size() == options.virtual_inputs.size() or
options.virtual_inputs.empty());
std::vector<src_file> srcs;
std::transform(migraphx_kernels().begin(),
migraphx_kernels().end(),
......
......@@ -6,12 +6,14 @@
#include <migraphx/par_for.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compile_pointwise.hpp>
#include <migraphx/gpu/compiler.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
struct precompile_op
{
operation op = op::identity{};
......@@ -38,41 +40,22 @@ struct precompile_op
MIGRAPHX_REGISTER_OP(precompile_op);
struct pointwise_compiler
struct compiled_result
{
std::string name() const { return "pointwise"; }
operation apply(context& ctx, instruction_ref ins, const operation&) const
{
assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front();
return compile_pointwise(ctx, to_shapes(ins->inputs()), *pm);
}
compiler_replace replace;
instruction_ref ins;
};
using compiler_function = std::function<operation(context&, instruction_ref, operation)>;
template <class T>
compiler_function make_compiler_function(T x)
template <class F>
void par_compile(std::size_t n, F f)
{
return {[=](auto&&... xs) { return x.apply(xs...); }};
if(n == 0)
return;
par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
}
template <class... Ts>
std::unordered_map<std::string, compiler_function> make_compilers(Ts... xs)
{
return {{xs.name(), make_compiler_function(xs)}...};
}
struct compiled_result
{
operation op;
instruction_ref ins;
};
void compile_ops::apply(module& m) const
{
auto compilers = make_compilers(pointwise_compiler{});
std::vector<std::function<compiled_result()>> compiles;
for(auto ins : iterator_for(m))
......@@ -80,15 +63,15 @@ void compile_ops::apply(module& m) const
if(ins->name() != "gpu::precompile_op")
continue;
operation preop = any_cast<precompile_op>(ins->get_operator()).op;
assert(contains(compilers, preop.name()));
auto c = compilers[preop.name()];
compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; });
compiles.emplace_back([=]() -> compiled_result {
return {compile(*ctx, ins, preop), ins};
});
}
std::vector<compiled_result> results(compiles.size());
par_for(compiles.size(), 1, [&](auto i) { results[i] = compiles[i](); });
par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
for(const auto& cr : results)
{
m.replace_instruction(cr.ins, cr.op, cr.ins->inputs());
cr.replace(m, cr.ins);
}
}
......
#include <migraphx/gpu/compile_pointwise.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void kernel(${params})
{
pointwise(${lambda}, ${args});
}
}
} // namespace migraphx
int main() {}
)__migraphx__";
operation compile_pointwise(context&,
const std::vector<shape>& inputs,
const std::string& lambda,
const std::string& preamble)
{
hip_compile_options options;
options.global = compute_global(inputs.front().elements());
options.local = 1024;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"lambda", lambda},
{"preamble", preamble}});
return compile_hip_code_object(src, options);
}
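The ${...} placeholders in pointwise_kernel above are filled by interpolate_string; a minimal stand-in (hypothetical helper, not MIGraphX's implementation) showing the substitution it performs:

#include <iostream>
#include <map>
#include <string>

std::string interpolate(std::string s, const std::map<std::string, std::string>& vars)
{
    for(const auto& [k, v] : vars)
    {
        const std::string key = "${" + k + "}";
        for(std::size_t pos = s.find(key); pos != std::string::npos; pos = s.find(key, pos))
        {
            s.replace(pos, key.size(), v); // splice the value over the placeholder
            pos += v.size();
        }
    }
    return s;
}

int main()
{
    std::string templ = "__global__ void kernel(${params}) { pointwise(${lambda}, ${args}); }";
    std::cout << interpolate(templ,
                             {{"params", "void* p0,void* p1"},
                              {"args", "p0,p1"},
                              {"lambda", "MIGRAPHX_LIFT(f)"}})
              << "\n";
    return 0;
}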
operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, module m)
{
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
auto name =
g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx