Commit 87697905 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Optimize the GRU operator with the enhanced dot operator.

parent fa4f5244
...@@ -513,32 +513,19 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -513,32 +513,19 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias // bias
instruction_ref brcst_bz{}; instruction_ref wbz{}, rbz{};
instruction_ref brcst_br{}; instruction_ref wbr{}, rbr{};
instruction_ref brcst_wbh{}; instruction_ref wbh{}, rbh{};
instruction_ref brcst_rbh{};
instruction_ref brcst_bh{};
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias); wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias); rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias); rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias); rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);
auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
} }
for(long i = 0; i < seq_len; i++) for(long i = 0; i < seq_len; i++)
...@@ -548,53 +535,74 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -548,53 +535,74 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz) // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz); instruction_ref xt_wz{};
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz); instruction_ref ht_rz{};
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz); if (bias != prog.end())
if(bias != prog.end()) {
xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz, wbz);
ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz, rbz);
}
else
{ {
xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz); xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
} }
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
auto zt = prog.insert_instruction(ins, actv_func1, xht_z); auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr); instruction_ref xt_wr{};
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr); instruction_ref ht_rr{};
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr); if (bias != prog.end())
if(bias != prog.end()) {
xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr, wbr);
ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr, rbr);
}
else
{ {
xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br); xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
} }
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
auto rt = prog.insert_instruction(ins, actv_func1, xht_r); auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
instruction_ref xht_h; instruction_ref xht_h;
if(linear_before_reset == 0) if(linear_before_reset == 0)
{ {
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh); instruction_ref xt_wh{};
instruction_ref rt_rh{};
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih); auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh); if (bias != prog.end())
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
{ {
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh); xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh, wbh);
rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh, rbh);
} }
else
{
xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
}
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
} }
else else
{ {
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh); instruction_ref ht1_rh{};
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh); instruction_ref xt_wh{};
if(bias != prog.end()) if(bias != prog.end())
{ {
ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh); xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh, wbh);
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh, rbh);
} }
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh); else
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
{ {
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh); xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
} }
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
} }
auto ht = prog.insert_instruction(ins, actv_func2, xht_h); auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment