Commit 8b3027ba authored by Shucai Xiao
Browse files

more optimization of the rnn operators.

parent 93ca0082
......@@ -208,16 +208,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto sih_lens = sih->get_shape().lens();
// bias
instruction_ref bwb{};
instruction_ref brb{};
instruction_ref bb{};
if(bias != prog.end())
{
long hs = static_cast<long>(r->get_shape().lens()[2]);
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
bwb = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wb);
brb = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rb);
auto wrb = prog.insert_instruction(ins, op::add{}, wb, rb);
bb = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
}
instruction_ref hidden_out = prog.end();
......@@ -229,17 +228,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
instruction_ref xt_wi{};
instruction_ref ht_ri{};
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
if(bias != prog.end())
{
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw, bwb);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr, brb);
}
else
{
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
}
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
......@@ -532,17 +525,12 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
instruction_ref xt_w{};
instruction_ref ih1_rzr{};
auto xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw);
auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
if(bias != prog.end())
{
xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw, bwb);
ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr, brb_zr);
}
else
{
xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw);
ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
xt_w = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
}
auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment