Commit 15db7a1e authored by Shucai Xiao
Browse files

code refactor for the gru operator

parent 86d48b8e
...@@ -534,77 +534,73 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward, ...@@ -534,77 +534,73 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq); auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz) instruction_ref zt{};
instruction_ref xt_wz{}; instruction_ref ht{};
instruction_ref ht_rz{}; if (bias != prog.end())
if(bias != prog.end()) {
{ // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz, wbz); auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz, wbz);
ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz, rbz); auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz, rbz);
} auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
else zt = prog.insert_instruction(ins, actv_func1, xht_z);
{
xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz); // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz); auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr, wbr);
} auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr, rbr);
auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz); auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
auto zt = prog.insert_instruction(ins, actv_func1, xht_z); auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) instruction_ref xht_h{};
instruction_ref xt_wr{}; auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh, wbh);
instruction_ref ht_rr{}; if(linear_before_reset == 0)
if(bias != prog.end())
{
xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr, wbr);
ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr, rbr);
}
else
{
xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
}
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
instruction_ref xht_h;
if(linear_before_reset == 0)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
instruction_ref xt_wh{};
instruction_ref rt_rh{};
auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
if(bias != prog.end())
{ {
xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh, wbh); // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh, rbh); auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh, rbh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
} }
else else
{ {
xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh); // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh); auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh, rbh);
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
} }
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh); ht = prog.insert_instruction(ins, actv_func2, xht_h);
} }
else else
{ {
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
instruction_ref ht1_rh{}; auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
instruction_ref xt_wh{}; auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
if(bias != prog.end()) auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
zt = prog.insert_instruction(ins, actv_func1, xht_z);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
instruction_ref xht_h;
auto xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
if(linear_before_reset == 0)
{ {
xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh, wbh); // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh, rbh); auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
} }
else else
{ {
xt_wh = prog.insert_instruction(ins, op::dot{}, xt, tran_wh); // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh); auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
} }
ht = prog.insert_instruction(ins, actv_func2, xht_h);
auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
xht_h = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
} }
auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1 // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt); auto one_minus_zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment