Commit 467649aa authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix a bug in gru pass.

parent 4c59b8fd
...@@ -276,18 +276,18 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward, ...@@ -276,18 +276,18 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
auto z1tht = prog.insert_instruction(ins, op::mul{}, z1t, ht); auto z1tht = prog.insert_instruction(ins, op::mul{}, z1t, ht);
auto ztht1 = prog.insert_instruction(ins, op::mul{}, zt, ih); auto ztht1 = prog.insert_instruction(ins, op::mul{}, zt, ih);
ih = prog.insert_instruction(ins, op::add{}, z1tht, ztht1); ih = prog.insert_instruction(ins, op::add{}, z1tht, ztht1);
final_out = ih; final_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, ih);
if(is_forward) if(is_forward)
{ {
hidden_out = hidden_out =
(seq_index == 0) ? ih : prog.insert_instruction(ins, op::concat{0}, hidden_out, ih); (seq_index == 0) ? final_out : prog.insert_instruction(ins, op::concat{0}, hidden_out, final_out);
} }
else else
{ {
hidden_out = (seq_index == seq_len - 1) hidden_out = (seq_index == seq_len - 1)
? ih ? final_out
: prog.insert_instruction(ins, op::concat{0}, ih, hidden_out); : prog.insert_instruction(ins, op::concat{0}, final_out, hidden_out);
} }
seq_index = is_forward ? (seq_index + 1) : (seq_index - 1); seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment