Commit 60bbf654 authored by Shucai Xiao
Browse files

Simplify the implementation of the RNN operators using the enhanced GEMM operator.

parent 6c77eae1
......@@ -206,14 +206,16 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias
instruction_ref bwb{};
instruction_ref brb{};
if(bias != prog.end())
{
long hs = r->get_shape().lens()[2];
long hs = static_cast<long>(r->get_shape().lens()[2]);
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
bwb = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wb);
brb = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rb);
}
instruction_ref hidden_out = prog.end();
......@@ -225,21 +227,22 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
long seq_index = is_forward ? i : (seq_len - 1 - i);
auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
instruction_ref ht;
if(bias != prog.end())
instruction_ref xt_wi{};
instruction_ref ht_ri{};
if (bias != prog.end())
{
ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw, bwb);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr, brb);
}
else
{
ht = xt_ht;
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
}
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
// apply activation function
ht = prog.insert_instruction(ins, actv_func, ht);
auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
sih = ht;
// add the dimensions of sequence length (axis 0 for sequence length,
......@@ -958,39 +961,38 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ic_shape = sic->get_shape();
// bias
instruction_ref bi_brcst{};
instruction_ref bo_brcst{};
instruction_ref bf_brcst{};
instruction_ref bc_brcst{};
instruction_ref wbi_brcst{}, rbi_brcst{};
instruction_ref wbo_brcst{}, rbo_brcst{};
instruction_ref wbf_brcst{}, rbf_brcst{};
instruction_ref wbc_brcst{}, rbc_brcst{};
if(bias != prog.end())
{
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto bxi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto bhi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto bi = prog.insert_instruction(ins, op::add{}, bxi, bhi);
bi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bi);
auto wbi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rbi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
wbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbi);
rbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbi);
auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
auto bo = prog.insert_instruction(ins, op::add{}, bxo, bho);
bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bo);
auto wbo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto rbo = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
wbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbo);
rbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbo);
auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
auto bhf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
auto bf = prog.insert_instruction(ins, op::add{}, bxf, bhf);
bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bf);
auto wbf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
auto rbf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
wbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbf);
rbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbf);
auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
auto bc = prog.insert_instruction(ins, op::add{}, bxc, bhc);
bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bc);
auto wbc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
wbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbc);
rbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbc);
}
// peep hole
instruction_ref pphi_brcst{};
instruction_ref ppho_brcst{};
instruction_ref pphf_brcst{};
if(pph != prog.end())
{
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
......@@ -1011,43 +1013,58 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
instruction_ref xt_wi{}, ht_ri{};
if (bias != prog.end())
{
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi, wbi_brcst);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri, rbi_brcst);
}
else
{
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
}
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
if(pph != prog.end())
{
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
}
if(bias != prog.end())
{
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
}
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
instruction_ref xt_wf{}, ht_rf{};
if (bias != prog.end())
{
xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf, wbf_brcst);
ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf, rbf_brcst);
}
else
{
xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
}
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
if(pph != prog.end())
{
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
}
if(bias != prog.end())
{
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
}
auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
// equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
auto xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
auto ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
instruction_ref xt_wc{}, ht_rc{};
if(bias != prog.end())
{
ct_before_actv = prog.insert_instruction(ins, op::add{}, ct_before_actv, bc_brcst);
xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc, wbc_brcst);
ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc, rbc_brcst);
}
else
{
xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
}
auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct
......@@ -1057,18 +1074,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
last_cell_output = cellt;
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
instruction_ref xt_wo{}, ht_ro{};
if(bias != prog.end())
{
xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo, wbo_brcst);
ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro, rbo_brcst);
}
else
{
xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
}
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
if(pph != prog.end())
{
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
}
if(bias != prog.end())
{
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
}
auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);
// Ht = ot (.) h(Ct)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment