Commit 7113cdfa authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix a bug in rnn operator pass.

parent 20f89fcc
...@@ -733,8 +733,8 @@ struct onnx_parser ...@@ -733,8 +733,8 @@ struct onnx_parser
result.push_back(hidden_states); result.push_back(hidden_states);
// second out for the last hidden state // second out for the last hidden state
// auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states); auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
// result.push_back(last_output); result.push_back(last_output);
return result; return result;
} }
......
...@@ -136,15 +136,15 @@ void rewrite_rnn::apply(program& prog) const ...@@ -136,15 +136,15 @@ void rewrite_rnn::apply(program& prog) const
// operator. Intuitively, we can do a slice on the input to get // operator. Intuitively, we can do a slice on the input to get
// the last output, but it is already existed in the rnn operator, // the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here // so we can just use it as the output here
// if (ins->name() == "rnn_last_output") if (ins->name() == "rnn_last_output")
//{ {
// // if rnn operator is executed, the last_output != prog.end() // if rnn operator is executed, the last_output != prog.end()
// if (last_output != prog.end()) if (last_output != prog.end())
// { {
// prog.replace_instruction(ins, op::identity{}, last_output); prog.replace_instruction(ins, op::identity{}, last_output);
// last_output = prog.end(); last_output = prog.end();
// } }
//} }
} }
} }
...@@ -161,7 +161,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward, ...@@ -161,7 +161,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
// squeeze and transpose w // squeeze and transpose w
std::vector<int64_t> perm{1, 0}; std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w); auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tran_sw = prog.insert_instruction(sw, op::transpose{perm}, sw); auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
// squeeze and transpose r // squeeze and transpose r
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r); auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment