Commit 8b3027ba authored by Shucai Xiao's avatar Shucai Xiao
Browse files

more optimization of the rnn operators.

parent 93ca0082
...@@ -208,16 +208,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -208,16 +208,15 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto sih_lens = sih->get_shape().lens(); auto sih_lens = sih->get_shape().lens();
// bias // bias
instruction_ref bwb{}; instruction_ref bb{};
instruction_ref brb{};
if(bias != prog.end()) if(bias != prog.end())
{ {
long hs = static_cast<long>(r->get_shape().lens()[2]); long hs = static_cast<long>(r->get_shape().lens()[2]);
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
bwb = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wb); auto wrb = prog.insert_instruction(ins, op::add{}, wb, rb);
brb = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, rb); bb = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
} }
instruction_ref hidden_out = prog.end(); instruction_ref hidden_out = prog.end();
...@@ -229,17 +228,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -229,17 +228,11 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
long seq_index = is_forward ? i : (seq_len - 1 - i); long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input); auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
instruction_ref xt_wi{}; auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
instruction_ref ht_ri{}; auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
if(bias != prog.end()) if(bias != prog.end())
{ {
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw, bwb); xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr, brb);
}
else
{
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
} }
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri); auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
...@@ -532,17 +525,12 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -532,17 +525,12 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq); auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
instruction_ref xt_w{}; auto xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw);
instruction_ref ih1_rzr{}; auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
if(bias != prog.end()) if(bias != prog.end())
{ {
xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw, bwb); xt_w = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr, brb_zr); ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
}
else
{
xt_w = prog.insert_instruction(ins, op::dot{}, xt, tw);
ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
} }
auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w); auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment