Commit c5a9d22f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent b7f5e9bd
...@@ -986,20 +986,20 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -986,20 +986,20 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref ppho_brcst{}; instruction_ref ppho_brcst{};
instruction_ref pphf_brcst{}; instruction_ref pphf_brcst{};
if (pph != prog.end()) if(pph != prog.end())
{ {
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph); auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph); auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi); pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst); pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph); auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho); ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst); ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph); auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf); pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst); pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
} }
for(long i = 0; i < seq_len; ++i) for(long i = 0; i < seq_len; ++i)
...@@ -1012,9 +1012,10 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1012,9 +1012,10 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi); auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri); auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri); auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
if (pph != prog.end()) { if(pph != prog.end())
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic); {
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct); auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
} }
if(bias != prog.end()) if(bias != prog.end())
{ {
...@@ -1026,10 +1027,10 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1026,10 +1027,10 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf); auto xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf); auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf); auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
if (pph != prog.end()) if(pph != prog.end())
{ {
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic); auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct); ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
} }
if(bias != prog.end()) if(bias != prog.end())
{ {
...@@ -1057,10 +1058,10 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1057,10 +1058,10 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo); auto xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro); auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro); auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
if (pph != prog.end()) if(pph != prog.end())
{ {
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt); auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt); ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
} }
if(bias != prog.end()) if(bias != prog.end())
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment