Commit 1133d782 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'gru_operator' into lstm_operator

parents 483c4508 82349b7d
......@@ -29,7 +29,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
assert(ins->name() == "rnn");
// could be 3 to 6 inputs, but the parse_rnn function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have only 3 arguments
// an onnx file. Another case is user can have num of arguments
// when writing their program.
auto args = ins->inputs();
......@@ -275,9 +275,10 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
// before rewrite the rnn operator, need to ensure
// we have 2 actv funcs. If less than 2, use the
// algorithm in parse_rnn to make 2 actv functions
// could be 3 to 6 inputs, but the parse_gru function will
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have any num of arguments
// when writing their program.
if(rnn_op.direction == op::rnn::bidirectional)
{
if(rnn_op.actv_funcs.empty())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment