Commit 71273501 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 28f8b1e8
......@@ -136,11 +136,11 @@ instruction_ref program::replace_instruction(instruction_ref ins, instruction_re
assert(has_instruction(rep));
assert(ins != rep);
if (ins == std::prev(this->end()))
if(ins == std::prev(this->end()))
{
// additional check to ensure the ins to be replaced is either
// the rnn_last_output, gru_last_output, or lstm_last_output
if (ins->name() == "rnn_last_output")
if(ins->name() == "rnn_last_output")
{
return replace_instruction(ins, op::identity{}, rep);
}
......
......@@ -85,20 +85,23 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse,
rnn_op.actv_funcs.at(1));
auto concat_output = prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
auto concat_output =
prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
if (ret_forward[0] == prog.end())
if(ret_forward[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
else
{
ret_forward[0] = prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] = prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
ret_forward[0] =
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
......@@ -134,10 +137,10 @@ void rewrite_rnn::apply(program& prog) const
is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
if (ret[0] == prog.end())
if(ret[0] == prog.end())
{
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
......@@ -204,7 +207,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
}
instruction_ref hidden_out = prog.end(), last_out;
std::size_t seq_len = input->get_shape().lens()[0];
std::size_t seq_len = input->get_shape().lens()[0];
for(std::size_t i = 0; i < seq_len; i++)
{
long seq_index = is_forward ? i : (seq_len - 1 - i);
......@@ -234,19 +237,21 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
// concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have
// output inserted
if (i < seq_len - 1)
if(i < seq_len - 1)
{
if(is_forward)
{
hidden_out = (seq_index == 0)
? last_out
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
hidden_out =
(seq_index == 0)
? last_out
: prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
}
else
{
hidden_out = (seq_index == seq_len - 1)
? last_out
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
hidden_out =
(seq_index == seq_len - 1)
? last_out
: prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment