Commit 87697905 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

optimize the gru operator with the enhanced dot operator.

parent fa4f5244
......@@ -513,32 +513,19 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias
instruction_ref brcst_bz{};
instruction_ref brcst_br{};
instruction_ref brcst_wbh{};
instruction_ref brcst_rbh{};
instruction_ref brcst_bh{};
instruction_ref wbz{}, rbz{};
instruction_ref wbr{}, rbr{};
instruction_ref wbh{}, rbh{};
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
brcst_wbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);
wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);
auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);
auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);
auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
}
for(long i = 0; i < seq_len; i++)
......@@ -548,53 +535,74 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
if(bias != prog.end())
instruction_ref xt_wz{};
instruction_ref ht_rz{};
if (bias != prog.end())
{
xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz, wbz);
ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz, rbz);
}
else
{
xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
}
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
if(bias != prog.end())
instruction_ref xt_wr{};
instruction_ref ht_rr{};
if (bias != prog.end())
{
xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr, wbr);
ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr, rbr);
}
else
{
xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
}
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
instruction_ref xht_h;
if(linear_before_reset == 0)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
instruction_ref xt_wh{};
instruction_ref rt_rh{};
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
if (bias != prog.end())
{
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh, wbh);
rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh, rbh);
}
else
{
xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
}
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
instruction_ref ht1_rh{};
instruction_ref xt_wh{};
if(bias != prog.end())
{
ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh, wbh);
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh, rbh);
}
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
if(bias != prog.end())
else
{
xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
}
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
}
auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment