Commit 71273501 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 28f8b1e8
...@@ -136,11 +136,11 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re ...@@ -136,11 +136,11 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert(has_instruction(rep)); assert(has_instruction(rep));
assert(ins != rep); assert(ins != rep);
if (ins == std::prev(this->end())) if(ins == std::prev(this->end()))
{ {
// additional check to ensure the ins to be replaced is either // additional check to ensure the ins to be replaced is either
// the rnn_last_output, gru_last_output, or lstm_last_output // the rnn_last_output, gru_last_output, or lstm_last_output
if (ins->name() == "rnn_last_output") if(ins->name() == "rnn_last_output")
{ {
return replace_instruction(ins, op::identity{}, rep); return replace_instruction(ins, op::identity{}, rep);
} }
......
...@@ -85,20 +85,23 @@ void rewrite_rnn::apply(program& prog) const ...@@ -85,20 +85,23 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse, ih_reverse,
rnn_op.actv_funcs.at(1)); rnn_op.actv_funcs.at(1));
auto concat_output = prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output); last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten from // The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction // rnn operator is a concat instruction
// sequence len is 1 // sequence len is 1
if (ret_forward[0] == prog.end()) if(ret_forward[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]); prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
} }
else else
{ {
ret_forward[0] = prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]); ret_forward[0] =
ret_reverse[0] = prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]); prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]}); prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
} }
} }
...@@ -134,10 +137,10 @@ void rewrite_rnn::apply(program& prog) const ...@@ -134,10 +137,10 @@ void rewrite_rnn::apply(program& prog) const
is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0)); is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a // following logic is to ensure the last instruction is a
// concat instruction // concat instruction
// sequence len is 1 // sequence len is 1
if (ret[0] == prog.end()) if(ret[0] == prog.end())
{ {
prog.replace_instruction(ins, op::concat{0}, ret[1]); prog.replace_instruction(ins, op::concat{0}, ret[1]);
} }
...@@ -204,7 +207,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward, ...@@ -204,7 +207,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
} }
instruction_ref hidden_out = prog.end(), last_out; instruction_ref hidden_out = prog.end(), last_out;
std::size_t seq_len = input->get_shape().lens()[0]; std::size_t seq_len = input->get_shape().lens()[0];
for(std::size_t i = 0; i < seq_len; i++) for(std::size_t i = 0; i < seq_len; i++)
{ {
long seq_index = is_forward ? i : (seq_len - 1 - i); long seq_index = is_forward ? i : (seq_len - 1 - i);
...@@ -234,19 +237,21 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward, ...@@ -234,19 +237,21 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
// concatenation for the last last_out is performed in the apply() // concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have // function to ensure the last instruction is concat, then we have
// output inserted // output inserted
if (i < seq_len - 1) if(i < seq_len - 1)
{ {
if(is_forward) if(is_forward)
{ {
hidden_out = (seq_index == 0) hidden_out =
? last_out (seq_index == 0)
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out); ? last_out
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
} }
else else
{ {
hidden_out = (seq_index == seq_len - 1) hidden_out =
? last_out (seq_index == seq_len - 1)
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out); ? last_out
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment