Commit 9025504f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

rename rnn to vanilla_rnn

parent 657c6996
...@@ -21,9 +21,9 @@ struct rewrite_rnn ...@@ -21,9 +21,9 @@ struct rewrite_rnn
void apply(program& prog) const; void apply(program& prog) const;
private: private:
// for vallina rnn operators // for vanilla rnn operators
void apply_vallina_rnn(program& prog, instruction_ref ins) const; void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> rnn_cell(bool is_forward, std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref input, instruction_ref input,
...@@ -32,7 +32,7 @@ struct rewrite_rnn ...@@ -32,7 +32,7 @@ struct rewrite_rnn
instruction_ref bias, instruction_ref bias,
instruction_ref ih, instruction_ref ih,
operation& actv_func) const; operation& actv_func) const;
std::vector<operation> rnn_actv_funcs(instruction_ref ins) const; std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators // for gru operators
void apply_gru(program& prog, instruction_ref ins) const; void apply_gru(program& prog, instruction_ref ins) const;
......
...@@ -14,7 +14,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -14,7 +14,7 @@ void rewrite_rnn::apply(program& prog) const
{ {
if(ins->name() == "rnn") if(ins->name() == "rnn")
{ {
apply_vallina_rnn(prog, ins); apply_vanilla_rnn(prog, ins);
} }
if(ins->name() == "gru") if(ins->name() == "gru")
...@@ -24,7 +24,7 @@ void rewrite_rnn::apply(program& prog) const ...@@ -24,7 +24,7 @@ void rewrite_rnn::apply(program& prog) const
} }
} }
void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
{ {
assert(ins->name() == "rnn"); assert(ins->name() == "rnn");
// could be 3 to 6 inputs, but the parse_rnn function will // could be 3 to 6 inputs, but the parse_rnn function will
...@@ -40,7 +40,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const ...@@ -40,7 +40,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}}; migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0); std::vector<float> data(ih_shape.elements(), 0);
auto actv_funcs = rnn_actv_funcs(ins); auto actv_funcs = vanilla_rnn_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator()); auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dicrt = rnn_op.direction; op::rnn::rnn_direction_t dicrt = rnn_op.direction;
instruction_ref last_output{}; instruction_ref last_output{};
...@@ -78,7 +78,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const ...@@ -78,7 +78,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data}); ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
} }
auto ret_forward = rnn_cell(true, auto ret_forward = vanilla_rnn_cell(true,
prog, prog,
ins, ins,
args[0], args[0],
...@@ -87,7 +87,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const ...@@ -87,7 +87,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
bias_forward, bias_forward,
ih_forward, ih_forward,
actv_funcs.at(0)); actv_funcs.at(0));
auto ret_reverse = rnn_cell(false, auto ret_reverse = vanilla_rnn_cell(false,
prog, prog,
ins, ins,
args[0], args[0],
...@@ -147,7 +147,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const ...@@ -147,7 +147,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
ih = prog.add_literal(migraphx::literal{ih_shape, data}); ih = prog.add_literal(migraphx::literal{ih_shape, data});
} }
auto ret = rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0)); auto ret = vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]); last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
// following logic is to ensure the last instruction is a // following logic is to ensure the last instruction is a
...@@ -183,7 +183,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const ...@@ -183,7 +183,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
} }
} }
std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward, std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
instruction_ref input, instruction_ref input,
...@@ -271,7 +271,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward, ...@@ -271,7 +271,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
return {hidden_out, last_out}; return {hidden_out, last_out};
} }
std::vector<operation> rewrite_rnn::rnn_actv_funcs(instruction_ref ins) const std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{ {
auto rnn_op = any_cast<op::rnn>(ins->get_operator()); auto rnn_op = any_cast<op::rnn>(ins->get_operator());
// before rewrite the rnn operator, need to ensure // before rewrite the rnn operator, need to ensure
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment