Commit 71273501 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 28f8b1e8
...@@ -136,11 +136,11 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re ...@@ -136,11 +136,11 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert(has_instruction(rep)); assert(has_instruction(rep));
assert(ins != rep); assert(ins != rep);
if (ins == std::prev(this->end())) if(ins == std::prev(this->end()))
{ {
// additional check to ensure the ins to be replaced is either // additional check to ensure the ins to be replaced is either
// the rnn_last_output, gru_last_output, or lstm_last_output // the rnn_last_output, gru_last_output, or lstm_last_output
if (ins->name() == "rnn_last_output") if(ins->name() == "rnn_last_output")
{ {
return replace_instruction(ins, op::identity{}, rep); return replace_instruction(ins, op::identity{}, rep);
} }
......
...@@ -85,20 +85,23 @@ void rewrite_rnn::apply(program& prog) const ...@@ -85,20 +85,23 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse, ih_reverse,
rnn_op.actv_funcs.at(1)); rnn_op.actv_funcs.at(1));
auto concat_output = prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output); last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten from // The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction // rnn operator is a concat instruction
// sequence len is 1 // sequence len is 1
if (ret_forward[0] == prog.end()) if(ret_forward[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
} }
else else
{ {
ret_forward[0] = prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]); ret_forward[0] =
ret_reverse[0] = prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]); prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]}); prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
} }
} }
...@@ -137,7 +140,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -137,7 +140,7 @@ void rewrite_rnn::apply(program& prog) const
// following logic is to ensure the last instruction is a // following logic is to ensure the last instruction is a
// concat instruction // concat instruction
// sequence len is 1 // sequence len is 1
if (ret[0] == prog.end()) if(ret[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{0}, ret[1]); prog.replace_instruction(ins, op::concat{0}, ret[1]);
} }
...@@ -234,17 +237,19 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward, ...@@ -234,17 +237,19 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
// concatenation for the last last_out is performed in the apply() // concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have // function to ensure the last instruction is concat, then we have
// output inserted // output inserted
if (i < seq_len - 1) if(i < seq_len - 1)
{ {
if(is_forward) if(is_forward)
{ {
hidden_out = (seq_index == 0) hidden_out =
(seq_index == 0)
? last_out ? last_out
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out); : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
} }
else else
{ {
hidden_out = (seq_index == seq_len - 1) hidden_out =
(seq_index == seq_len - 1)
? last_out ? last_out
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out); : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment