Commit 6eb69989 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add cpu test for gru operator.

parent 1d99f07b
...@@ -25,7 +25,7 @@ void rewrite_gru::apply(program& prog) const ...@@ -25,7 +25,7 @@ void rewrite_gru::apply(program& prog) const
std::size_t batch_size = seq_shape.lens()[1]; std::size_t batch_size = seq_shape.lens()[1];
shape::type_t type = seq_shape.type(); shape::type_t type = seq_shape.type();
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}}; migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<char> data(ih_shape.bytes(), 0); std::vector<float> data(ih_shape.elements(), 0.0);
auto gru_op = any_cast<op::gru>(ins->get_operator()); auto gru_op = any_cast<op::gru>(ins->get_operator());
op::gru::gru_direction_t dicrt = gru_op.direction; op::gru::gru_direction_t dicrt = gru_op.direction;
...@@ -41,8 +41,8 @@ void rewrite_gru::apply(program& prog) const ...@@ -41,8 +41,8 @@ void rewrite_gru::apply(program& prog) const
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]); auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// bias // bias
instruction_ref bias_forward, bias_reverse; instruction_ref bias_forward = prog.end();
bias_forward = bias_reverse = prog.end(); instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined") if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{ {
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
...@@ -50,7 +50,8 @@ void rewrite_gru::apply(program& prog) const ...@@ -50,7 +50,8 @@ void rewrite_gru::apply(program& prog) const
} }
// intial hidden state // intial hidden state
instruction_ref ih_forward, ih_reverse; instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->get_operator().name() != "undefined") if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
{ {
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]); ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
...@@ -117,7 +118,7 @@ void rewrite_gru::apply(program& prog) const ...@@ -117,7 +118,7 @@ void rewrite_gru::apply(program& prog) const
} }
// intial hidden state // intial hidden state
instruction_ref ih; instruction_ref ih{};
if(args.size() == 6 && args[5]->get_operator().name() != "undefined") if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
{ {
ih = args[5]; ih = args[5];
...@@ -215,7 +216,11 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward, ...@@ -215,7 +216,11 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih); auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
// bias // bias
instruction_ref brcst_bz, brcst_br, brcst_wbh, brcst_rbh, brcst_bh; instruction_ref brcst_bz{};
instruction_ref brcst_br{};
instruction_ref brcst_wbh{};
instruction_ref brcst_rbh{};
instruction_ref brcst_bh{};
if(bias != prog.end()) if(bias != prog.end())
{ {
auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias); auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment