Commit c194c35c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code cleanup

parent 82349b7d
...@@ -313,9 +313,10 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -313,9 +313,10 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{ {
assert(ins->name() == "gru"); assert(ins->name() == "gru");
const auto actv_funcs = gru_actv_funcs(ins); const auto actv_funcs = gru_actv_funcs(ins);
// could be 3 to 5 inputs (though onnx::rnn has 6 inputs, // could be 3 to 6 inputs, but the parse_gru function will
// the 5th one is undefined and ignored by protobuf. so // append undefined operators to make 6 arguments when parsing
// we need to process up to 5 inputs // an onnx file. Another case is user can have num of arguments
// when writing their program.
auto args = ins->inputs(); auto args = ins->inputs();
shape seq_shape = args[0]->get_shape(); shape seq_shape = args[0]->get_shape();
...@@ -383,11 +384,9 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -383,11 +384,9 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
// The following logic is to ensure the last instruction rewritten // The following logic is to ensure the last instruction rewritten
// from gru operator is a concat // from gru operator is a concat
instruction_ref hidden_state{};
if(ret_forward[0] == prog.end()) if(ret_forward[0] == prog.end())
{ {
hidden_state = prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
} }
else else
{ {
...@@ -395,8 +394,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -395,8 +394,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]); prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
ret_reverse[0] = ret_reverse[0] =
prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]); prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
hidden_state = prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
} }
} }
else else
...@@ -434,16 +432,15 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -434,16 +432,15 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
instruction_ref hidden_state{};
if(ret[0] == prog.end()) if(ret[0] == prog.end())
{ {
hidden_state = prog.replace_instruction(ins, op::concat{0}, ret[1]); prog.replace_instruction(ins, op::concat{0}, ret[1]);
} }
else else
{ {
auto concat_arg0 = is_forward ? ret[0] : ret[1]; auto concat_arg0 = is_forward ? ret[0] : ret[1];
auto concat_arg1 = is_forward ? ret[1] : ret[0]; auto concat_arg1 = is_forward ? ret[1] : ret[0];
hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1); prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment