Commit fe30f007 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 60bbf654
...@@ -214,8 +214,8 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -214,8 +214,8 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); auto rb = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
bwb = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wb); bwb = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wb);
brb = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rb); brb = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rb);
} }
instruction_ref hidden_out = prog.end(); instruction_ref hidden_out = prog.end();
...@@ -229,7 +229,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -229,7 +229,7 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt); xt = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
instruction_ref xt_wi{}; instruction_ref xt_wi{};
instruction_ref ht_ri{}; instruction_ref ht_ri{};
if (bias != prog.end()) if(bias != prog.end())
{ {
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw, bwb); xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw, bwb);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr, brb); ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr, brb);
...@@ -237,13 +237,13 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward, ...@@ -237,13 +237,13 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
else else
{ {
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw); xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr); ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
} }
auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri); auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
// apply activation function // apply activation function
auto ht = prog.insert_instruction(ins, actv_func, xt_ht); auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
sih = ht; sih = ht;
// add the dimensions of sequence length (axis 0 for sequence length, // add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions // axis 1 for num_directions
...@@ -970,23 +970,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -970,23 +970,23 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
auto wbi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias); auto wbi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
auto rbi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias); auto rbi = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
wbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbi); wbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbi);
rbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbi); rbi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbi);
auto wbo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias); auto wbo = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
auto rbo = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias); auto rbo = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
wbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbo); wbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbo);
rbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbo); rbo_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbo);
auto wbf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias); auto wbf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
auto rbf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias); auto rbf = prog.insert_instruction(ins, op::slice{{0}, {6 * hs}, {7 * hs}}, sbias);
wbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbf); wbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbf);
rbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbf); rbf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbf);
auto wbc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias); auto wbc = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
auto rbc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias); auto rbc = prog.insert_instruction(ins, op::slice{{0}, {7 * hs}, {8 * hs}}, sbias);
wbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbc); wbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, wbc);
rbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbc); rbc_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, rbc);
} }
// peep hole // peep hole
...@@ -1014,15 +1014,15 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1014,15 +1014,15 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) // equation it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
instruction_ref xt_wi{}, ht_ri{}; instruction_ref xt_wi{}, ht_ri{};
if (bias != prog.end()) if(bias != prog.end())
{ {
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi, wbi_brcst); xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi, wbi_brcst);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri, rbi_brcst); ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri, rbi_brcst);
} }
else else
{ {
xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi); xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_wi);
ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri); ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
} }
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri); auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
if(pph != prog.end()) if(pph != prog.end())
...@@ -1034,15 +1034,15 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1034,15 +1034,15 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
// equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) // equation ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
instruction_ref xt_wf{}, ht_rf{}; instruction_ref xt_wf{}, ht_rf{};
if (bias != prog.end()) if(bias != prog.end())
{ {
xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf, wbf_brcst); xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf, wbf_brcst);
ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf, rbf_brcst); ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf, rbf_brcst);
} }
else else
{ {
xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf); xt_wf = prog.insert_instruction(ins, op::dot{}, xt, tran_wf);
ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf); ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
} }
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf); auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
if(pph != prog.end()) if(pph != prog.end())
...@@ -1056,16 +1056,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1056,16 +1056,16 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref xt_wc{}, ht_rc{}; instruction_ref xt_wc{}, ht_rc{};
if(bias != prog.end()) if(bias != prog.end())
{ {
xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc, wbc_brcst); xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc, wbc_brcst);
ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc, rbc_brcst); ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc, rbc_brcst);
} }
else else
{ {
xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc); xt_wc = prog.insert_instruction(ins, op::dot{}, xt, tran_wc);
ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc); ht_rc = prog.insert_instruction(ins, op::dot{}, sih, tran_rc);
} }
auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc); auto ct_before_actv = prog.insert_instruction(ins, op::add{}, xt_wc, ht_rc);
auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv); auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
// equation Ct = ft (.) Ct-1 + it (.) ct // equation Ct = ft (.) Ct-1 + it (.) ct
auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic); auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic);
...@@ -1077,13 +1077,13 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -1077,13 +1077,13 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref xt_wo{}, ht_ro{}; instruction_ref xt_wo{}, ht_ro{};
if(bias != prog.end()) if(bias != prog.end())
{ {
xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo, wbo_brcst); xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo, wbo_brcst);
ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro, rbo_brcst); ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro, rbo_brcst);
} }
else else
{ {
xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo); xt_wo = prog.insert_instruction(ins, op::dot{}, xt, tran_wo);
ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro); ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
} }
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro); auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
if(pph != prog.end()) if(pph != prog.end())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment