Commit 34d9ed70 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'rnn_operator' into gru_operator

parents 80016cff d3a09f1a
......@@ -43,8 +43,8 @@ void rewrite_rnn::apply(program& prog) const
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// process bias
instruction_ref bias_forward, bias_reverse;
bias_forward = bias_reverse = prog.end();
instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
......@@ -53,7 +53,8 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward, ih_reverse;
instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
{
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
......@@ -215,7 +216,8 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
}
instruction_ref hidden_out = prog.end(), last_out;
instruction_ref hidden_out = prog.end();
instruction_ref last_out{};
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
std::size_t seq_len = input->get_shape().lens()[0];
for(std::size_t i = 0; i < seq_len; i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment