Commit 7113cdfa authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix a bug in rnn operator pass.

parent 20f89fcc
......@@ -733,8 +733,8 @@ struct onnx_parser
result.push_back(hidden_states);
// second out for the last hidden state
// auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
// result.push_back(last_output);
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
result.push_back(last_output);
return result;
}
......
......@@ -136,15 +136,15 @@ void rewrite_rnn::apply(program& prog) const
// operator. Intuitively, we can do a slice on the input to get
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
// if (ins->name() == "rnn_last_output")
//{
// // if rnn operator is executed, the last_output != prog.end()
// if (last_output != prog.end())
// {
// prog.replace_instruction(ins, op::identity{}, last_output);
// last_output = prog.end();
// }
//}
if (ins->name() == "rnn_last_output")
{
// if rnn operator is executed, the last_output != prog.end()
if (last_output != prog.end())
{
prog.replace_instruction(ins, op::identity{}, last_output);
last_output = prog.end();
}
}
}
}
......@@ -161,7 +161,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
// squeeze and transpose w
std::vector<int64_t> perm{1, 0};
auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
auto tran_sw = prog.insert_instruction(sw, op::transpose{perm}, sw);
auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
// squeeze and transpose r
auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment