Commit 7bbacda0 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 5bd83328
......@@ -985,30 +985,50 @@ struct onnx_parser
switch(vec_names.size())
{
case 1:
vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0),
vec_names.at(0), vec_names.at(0), vec_names.at(0)};
vec_names = {vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0),
vec_names.at(0)};
break;
case 2:
// repeat the 2nd actv func once, then repeat all three another time
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1),
vec_names.at(0), vec_names.at(1), vec_names.at(1)};
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(1),
vec_names.at(0),
vec_names.at(1),
vec_names.at(1)};
break;
case 3:
// repeat all three actv funcs once
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(2),
vec_names.at(0), vec_names.at(1), vec_names.at(2)};
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(0),
vec_names.at(1),
vec_names.at(2)};
break;
case 4:
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(2),
vec_names.at(3), vec_names.at(3), vec_names.at(3)};
case 4:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(3),
vec_names.at(3)};
break;
case 5:
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(2),
vec_names.at(3), vec_names.at(4), vec_names.at(4)};
case 5:
vec_names = {vec_names.at(0),
vec_names.at(1),
vec_names.at(2),
vec_names.at(3),
vec_names.at(4),
vec_names.at(4)};
break;
default: break;
......@@ -1018,13 +1038,11 @@ struct onnx_parser
{
switch(vec_names.size())
{
case 1:
vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)};
break;
case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;
case 2:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
break;
default: break;
......
......@@ -994,15 +994,15 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto spph = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
auto pphi = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
auto pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphi);
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
pphi_brcst = prog.insert_instruction(ins, op::contiguous{}, pphi_brcst);
auto ppho = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
auto ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, ppho);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
ppho_brcst = prog.insert_instruction(ins, op::contiguous{}, ppho_brcst);
auto pphf = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
auto pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_shape}, pphf);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
pphf_brcst = prog.insert_instruction(ins, op::contiguous{}, pphf_brcst);
for(long i = 0; i < seq_len; ++i)
{
......@@ -1015,7 +1015,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_ri);
auto pphi_ct = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
auto it_before_actv = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);
if(bias != prog.end())
{
it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, bi_brcst);
......@@ -1027,7 +1027,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ht_rf = prog.insert_instruction(ins, op::dot{}, sih, tran_rf);
auto pphf_ct = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
auto ft_before_actv = prog.insert_instruction(ins, op::add{}, xt_wf, ht_rf);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
if(bias != prog.end())
{
ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, bf_brcst);
......@@ -1055,7 +1055,7 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto ht_ro = prog.insert_instruction(ins, op::dot{}, sih, tran_ro);
auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
auto ot_before_actv = prog.insert_instruction(ins, op::add{}, xt_wo, ht_ro);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
if(bias != prog.end())
{
ot_before_actv = prog.insert_instruction(ins, op::add{}, ot_before_actv, bo_brcst);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment