Commit 9025504f authored by Shucai Xiao

rename rnn to vanilla_rnn

parent 657c6996
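
For context, a minimal sketch of how this pass is driven. Only `rewrite_rnn::apply` is visible in the diff below; the program-loading helper here is a hypothetical stand-in, not part of this commit:

    // Sketch under assumptions: rewrite_rnn walks the program and expands
    // every "rnn" instruction, now via the renamed apply_vanilla_rnn path.
    migraphx::program prog = load_program_with_rnn(); // hypothetical helper
    migraphx::rewrite_rnn pass;
    pass.apply(prog); // each "rnn" op is lowered through vanilla_rnn_cell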
@@ -21,9 +21,9 @@ struct rewrite_rnn
     void apply(program& prog) const;
 
     private:
-    // for vallina rnn operators
-    void apply_vallina_rnn(program& prog, instruction_ref ins) const;
-    std::vector<instruction_ref> rnn_cell(bool is_forward,
+    // for vanilla rnn operators
+    void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
+    std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
@@ -32,7 +32,7 @@ struct rewrite_rnn
                                                   instruction_ref bias,
                                                   instruction_ref ih,
                                                   operation& actv_func) const;
-    std::vector<operation> rnn_actv_funcs(instruction_ref ins) const;
+    std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
 
     // for gru operators
     void apply_gru(program& prog, instruction_ref ins) const;
......
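
For reference, `vanilla_rnn_cell` assembles graph instructions for the standard (vanilla) RNN recurrence as defined by the ONNX RNN operator that `parse_rnn` maps from; the symbol names below follow the ONNX specification, not this codebase:

    H_t = f(X_t * W^T + H_{t-1} * R^T + Wb + Rb)

where `f` is the activation function selected per direction by `vanilla_rnn_actv_funcs`.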
@@ -14,7 +14,7 @@ void rewrite_rnn::apply(program& prog) const
     {
         if(ins->name() == "rnn")
         {
-            apply_vallina_rnn(prog, ins);
+            apply_vanilla_rnn(prog, ins);
         }
 
         if(ins->name() == "gru")
@@ -24,7 +24,7 @@ void rewrite_rnn::apply(program& prog) const
     }
 }
 
-void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
+void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
 {
     assert(ins->name() == "rnn");
     // could be 3 to 6 inputs, but the parse_rnn function will
@@ -40,7 +40,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
     migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
     std::vector<float> data(ih_shape.elements(), 0);
 
-    auto actv_funcs = rnn_actv_funcs(ins);
+    auto actv_funcs = vanilla_rnn_actv_funcs(ins);
     auto rnn_op = any_cast<op::rnn>(ins->get_operator());
     op::rnn::rnn_direction_t dicrt = rnn_op.direction;
     instruction_ref last_output{};
@@ -78,7 +78,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
             ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
         }
 
-        auto ret_forward = rnn_cell(true,
+        auto ret_forward = vanilla_rnn_cell(true,
                                             prog,
                                             ins,
                                             args[0],
@@ -87,7 +87,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
                                             bias_forward,
                                             ih_forward,
                                             actv_funcs.at(0));
-        auto ret_reverse = rnn_cell(false,
+        auto ret_reverse = vanilla_rnn_cell(false,
                                             prog,
                                             ins,
                                             args[0],
@@ -147,7 +147,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
             ih = prog.add_literal(migraphx::literal{ih_shape, data});
         }
 
-        auto ret = rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
+        auto ret = vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
         last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
 
         // following logic is to ensure the last instruction is a
@@ -183,7 +183,7 @@ void rewrite_rnn::apply_vallina_rnn(program& prog, instruction_ref ins) const
     }
 }
 
-std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
+std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
                                                            program& prog,
                                                            instruction_ref ins,
                                                            instruction_ref input,
@@ -271,7 +271,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
     return {hidden_out, last_out};
 }
 
-std::vector<operation> rewrite_rnn::rnn_actv_funcs(instruction_ref ins) const
+std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
 {
     auto rnn_op = any_cast<op::rnn>(ins->get_operator());
     // before rewrite the rnn operator, need to ensure
......
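
To make the renamed cell concrete, here is a naive single-direction reference of the per-step computation that `vanilla_rnn_cell` emits as graph instructions. This is a sketch, not MIGraphX code: the row-major layouts, the fused bias, and the function name are assumptions, and tanh is the ONNX default activation:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Hypothetical reference, not the pass itself: one step of
    // H_t = tanh(X_t * W^T + H_{t-1} * R^T + b), with b = Wb + Rb fused.
    std::vector<float> vanilla_rnn_step(const std::vector<float>& x, // X_t, size I
                                        const std::vector<float>& h, // H_{t-1}, size H
                                        const std::vector<float>& w, // W, H x I row-major
                                        const std::vector<float>& r, // R, H x H row-major
                                        const std::vector<float>& b) // Wb + Rb, size H
    {
        const std::size_t H = h.size();
        const std::size_t I = x.size();
        std::vector<float> out(H);
        for(std::size_t j = 0; j < H; ++j)
        {
            float acc = b[j];
            for(std::size_t k = 0; k < I; ++k)
                acc += w[j * I + k] * x[k]; // X_t * W^T contribution
            for(std::size_t k = 0; k < H; ++k)
                acc += r[j * H + k] * h[k]; // H_{t-1} * R^T contribution
            out[j] = std::tanh(acc);
        }
        return out;
    }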