Commit 34d9ed70 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'rnn_operator' into gru_operator

parents 80016cff d3a09f1a
...@@ -43,8 +43,8 @@ void rewrite_rnn::apply(program& prog) const ...@@ -43,8 +43,8 @@ void rewrite_rnn::apply(program& prog) const
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]); auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// process bias // process bias
instruction_ref bias_forward, bias_reverse; instruction_ref bias_forward = prog.end();
bias_forward = bias_reverse = prog.end(); instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined") if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{ {
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
...@@ -53,7 +53,8 @@ void rewrite_rnn::apply(program& prog) const ...@@ -53,7 +53,8 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state, it could be the 6th argument // process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored) // or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward, ih_reverse; instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->get_operator().name() != "undefined") if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
{ {
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]); ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
...@@ -215,9 +216,10 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward, ...@@ -215,9 +216,10 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b); bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
} }
instruction_ref hidden_out = prog.end(), last_out; instruction_ref hidden_out = prog.end();
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih); instruction_ref last_out{};
std::size_t seq_len = input->get_shape().lens()[0]; last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
std::size_t seq_len = input->get_shape().lens()[0];
for(std::size_t i = 0; i < seq_len; i++) for(std::size_t i = 0; i < seq_len; i++)
{ {
long seq_index = is_forward ? i : (seq_len - 1 - i); long seq_index = is_forward ? i : (seq_len - 1 - i);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment