Commit 08207d66 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fixed bugs for gru operator.

parent 6b6f0f05
......@@ -103,15 +103,15 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(2),
gru_op.actv_funcs.at(3));
// auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1],
// ret_reverse[1]);
auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
// add the dimension of num_direction
ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
// concat the forward and reverse output
prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
auto replaced_arg = prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
replaced_arg->add_output(final_output);
}
else
{
......@@ -151,7 +151,8 @@ void rewrite_gru::apply(program& prog) const
gru_op.actv_funcs.at(1));
// add the dimension of num_direction
prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
auto replaced_arg = prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
replaced_arg->add_output(ret[1]);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment