Commit 2ce55c6e authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 0ce9fcbb
...@@ -968,39 +968,39 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -968,39 +968,39 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref bo_brcst{}; instruction_ref bo_brcst{};
instruction_ref bf_brcst{}; instruction_ref bf_brcst{};
instruction_ref bc_brcst{}; instruction_ref bc_brcst{};
if (bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto bxi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto bxi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto bhi = prog.insert_instruction(ins, op::slice{{0}, {4*hs}, {5*hs}}, sbias); auto bhi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
auto bi = prog.insert_instruction(ins, op::add{}, bxi, bhi); auto bi = prog.insert_instruction(ins, op::add{}, bxi, bhi);
bi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bi); bi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bi);
auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2*hs}}, sbias); auto bxo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto bho = prog.insert_instruction(ins, op::slice{{0}, {5*hs}, {6*hs}}, sbias); auto bho = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
auto bo = prog.insert_instruction(ins, op::add{}, bxo, bho); auto bo = prog.insert_instruction(ins, op::add{}, bxo, bho);
bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bo); bo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bo);
auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2*hs}, {3*hs}}, sbias); auto bxf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
auto bhf = prog.insert_instruction(ins, op::slice{{0}, {5*hs}, {6*hs}}, sbias); auto bhf = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
auto bf = prog.insert_instruction(ins, op::add{}, bxf, bhf); auto bf = prog.insert_instruction(ins, op::add{}, bxf, bhf);
bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bf); bf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bf);
auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3*hs}, {4*hs}}, sbias); auto bxc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7*hs}, {8*hs}}, sbias); auto bhc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
auto bc = prog.insert_instruction(ins, op::add{}, bxc, bhc); auto bc = prog.insert_instruction(ins, op::add{}, bxc, bhc);
bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bc); bc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, bc);
} }
// peep hole // peep hole
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph); auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph); auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
auto pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi); auto pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph); auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
auto ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho); auto ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph); auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
auto pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf); auto pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
for(long i = 0; i < seq_len; ++i) for(long i = 0; i < seq_len; ++i)
...@@ -1010,49 +1010,49 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1010,49 +1010,49 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) // equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi); auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri); auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic); auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri, pphi_ct); auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri, pphi_ct);
if (bias != prog.end()) if(bias != prog.end())
{ {
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst); it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
} }
auto it = prog.insert_instruction(ins, actv_func1, it_before_actv); auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) // equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf); auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf); auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic); auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf, pphf_ct); auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf, pphf_ct);
if (bias != prog.end()) if(bias != prog.end())
{ {
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst); ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
} }
auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv); auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
// equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) // equation ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
auto xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc); auto xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
auto ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc); auto ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc); auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
if (bias != prog.end()) if(bias != prog.end())
{ {
ct_before_actv = prog.insert_instruction(ins, op::add{}, ct_before_actv, bc_brcst); ct_before_actv = prog.insert_instruction(ins, op::add{}, ct_before_actv, bc_brcst);
} }
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv); auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct // equation Ct = ft (.) Ct-1 + it (.) ct
auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic); auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic);
auto it_ct = prog.insert_instruction(ins, op::mul{}, it, ct); auto it_ct = prog.insert_instruction(ins, op::mul{}, it, ct);
auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct); auto cellt = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
last_cell_output = cellt; last_cell_output = cellt;
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo); auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro); auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt); auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro, ppho_cellt); auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro, ppho_cellt);
if (bias != prog.end()) if(bias != prog.end())
{ {
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst); ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
} }
...@@ -1060,26 +1060,26 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1060,26 +1060,26 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// Ht = ot (.) h(Ct) // Ht = ot (.) h(Ct)
auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt); auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
auto ht = prog.insert_instruction(ins, op::mul{}, ot, h_cellt); auto ht = prog.insert_instruction(ins, op::mul{}, ot, h_cellt);
sic = cellt; sic = cellt;
sih = ht; sih = ht;
last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht); last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
if (i < seq_len - 1) if(i < seq_len - 1)
{ {
if (i == 0) if(i == 0)
{ {
hidden_states = last_output; hidden_states = last_output;
} }
else else
{ {
auto concat_arg0 = is_forward ? hidden_states : last_output; auto concat_arg0 = is_forward ? hidden_states : last_output;
auto concat_arg1 = is_forward ? last_output : hidden_states; auto concat_arg1 = is_forward ? last_output : hidden_states;
hidden_states = prog.insert_instruction(ins, op::concat{0}, concat_arg0, concat_arg1); hidden_states =
prog.insert_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment