Commit 6c97a744 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

further optimization of lstm operator.

parent 8b3027ba
...@@ -910,19 +910,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -910,19 +910,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ic_lens = sic->get_shape().lens(); auto ic_lens = sic->get_shape().lens();
// bias // bias
instruction_ref wb{}; instruction_ref wrb{};
instruction_ref rb{};
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto ub_wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias); auto ub_wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
auto ub_rb = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias); auto ub_rb = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
wb = prog.insert_instruction( wrb = prog.insert_instruction(
ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wb); ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
rb = prog.insert_instruction(
ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_rb);
} }
// peep hole // peep hole
...@@ -948,18 +946,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -948,18 +946,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq); auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
instruction_ref xt_sih{}; auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw);
auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
auto xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
if(bias != prog.end()) if(bias != prog.end())
{ {
auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw, wb); xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr, rb);
xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
}
else
{
auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw);
auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
} }
auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih); auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment