Commit 6c97a744 authored by Shucai Xiao
Browse files

Further optimization of the LSTM operator.

parent 8b3027ba
......@@ -910,19 +910,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ic_lens = sic->get_shape().lens();
// bias
instruction_ref wb{};
instruction_ref rb{};
instruction_ref wrb{};
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto ub_wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
auto ub_rb = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
wb = prog.insert_instruction(
ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wb);
rb = prog.insert_instruction(
ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_rb);
wrb = prog.insert_instruction(
ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
}
// peep hole
......@@ -948,18 +946,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
instruction_ref xt_sih{};
if(bias != prog.end())
{
auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw, wb);
auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr, rb);
xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
}
else
{
auto xt_tsw = prog.insert_instruction(ins, op::dot{}, xt, tsw);
auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
auto xt_sih = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
if(bias != prog.end())
{
xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
}
auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment