Commit 398c0157 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'gru_operator' into lstm_operator

parents 67ecbc57 d4aa7d46
...@@ -57,7 +57,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const ...@@ -57,7 +57,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
// process bias // process bias
instruction_ref bias_forward = prog.end(); instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end(); instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined") if(args.size() >= 4 && args[3]->name() != "undefined")
{ {
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
...@@ -67,7 +67,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const ...@@ -67,7 +67,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
// or the 5th one (if the sequence len argument is ignored) // or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward{}; instruction_ref ih_forward{};
instruction_ref ih_reverse{}; instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->get_operator().name() != "undefined") if(args.size() == 6 && args[5]->name() != "undefined")
{ {
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]); ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]); ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
...@@ -131,14 +131,14 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const ...@@ -131,14 +131,14 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
// process bias and initial hidden state // process bias and initial hidden state
instruction_ref bias = prog.end(); instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined") if(args.size() >= 4 && args[3]->name() != "undefined")
{ {
bias = args[3]; bias = args[3];
} }
// process intial hidden state // process intial hidden state
instruction_ref ih; instruction_ref ih;
if(args.size() == 6 && args[5]->get_operator().name() != "undefined") if(args.size() == 6 && args[5]->name() != "undefined")
{ {
ih = args[5]; ih = args[5];
} }
...@@ -340,7 +340,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -340,7 +340,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
// bias // bias
instruction_ref bias_forward = prog.end(); instruction_ref bias_forward = prog.end();
instruction_ref bias_reverse = prog.end(); instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined") if(args.size() >= 4 && args[3]->name() != "undefined")
{ {
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]); bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
...@@ -349,7 +349,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -349,7 +349,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
// intial hidden state // intial hidden state
instruction_ref ih_forward{}; instruction_ref ih_forward{};
instruction_ref ih_reverse{}; instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->get_operator().name() != "undefined") if(args.size() == 6 && args[5]->name() != "undefined")
{ {
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]); ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]); ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
...@@ -407,14 +407,14 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -407,14 +407,14 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
// bias // bias
instruction_ref bias = prog.end(); instruction_ref bias = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined") if(args.size() >= 4 && args[3]->name() != "undefined")
{ {
bias = args[3]; bias = args[3];
} }
// intial hidden state // intial hidden state
instruction_ref ih{}; instruction_ref ih{};
if(args.size() == 6 && args[5]->get_operator().name() != "undefined") if(args.size() == 6 && args[5]->name() != "undefined")
{ {
ih = args[5]; ih = args[5];
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment