Commit 35bc9bc7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

remove the unnecessary hidden_state variable

parent 5571a352
......@@ -104,11 +104,9 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
instruction_ref hidden_output{};
if(ret_forward[0] == prog.end())
{
hidden_output =
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
}
else
{
......@@ -116,8 +114,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
hidden_output =
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
}
}
else
......@@ -154,16 +151,15 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
instruction_ref hidden_output{};
if(ret[0] == prog.end())
{
hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
prog.replace_instruction(ins, op::concat{0}, ret[1]);
}
else
{
auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0];
hidden_output = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment