Commit 1877b194 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix comments about enum rnn_direction, rnn_last_output, and added more tests...

fix comments about enum rnn_direction, rnn_last_output, and added more tests to achieve better code coverage
parent 35bc9bc7
...@@ -1164,19 +1164,19 @@ struct outline ...@@ -1164,19 +1164,19 @@ struct outline
argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; } argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
}; };
struct rnn // indicate rnn computation direction
enum class rnn_direction
{ {
forward,
reverse,
bidirectional,
};
enum rnn_direction_t struct rnn
{ {
forward,
reverse,
bidirectional,
};
std::size_t hidden_size = 1; std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{tanh{}, tanh{}}; std::vector<operation> actv_funcs{tanh{}, tanh{}};
rnn_direction_t direction = forward; rnn_direction direction = rnn_direction::forward;
float clip = 0.0f; float clip = 0.0f;
std::string name() const { return "rnn"; } std::string name() const { return "rnn"; }
...@@ -1190,7 +1190,7 @@ struct rnn ...@@ -1190,7 +1190,7 @@ struct rnn
} }
std::size_t num_directions = 1; std::size_t num_directions = 1;
if(direction == bidirectional) if(direction == rnn_direction::bidirectional)
{ {
num_directions = 2; num_directions = 2;
} }
...@@ -1224,16 +1224,9 @@ struct rnn_last_output ...@@ -1224,16 +1224,9 @@ struct rnn_last_output
struct gru struct gru
{ {
enum gru_direction_t
{
forward,
reverse,
bidirectional,
};
std::size_t hidden_size = 1; std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}}; std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
gru_direction_t direction = forward; rnn_direction direction = rnn_direction::forward;
float clip = 0.0f; float clip = 0.0f;
int linear_before_reset = 0; int linear_before_reset = 0;
...@@ -1248,7 +1241,7 @@ struct gru ...@@ -1248,7 +1241,7 @@ struct gru
} }
std::size_t num_directions = 1; std::size_t num_directions = 1;
if(direction == bidirectional) if(direction == rnn_direction::bidirectional)
{ {
num_directions = 2; num_directions = 2;
} }
...@@ -1266,20 +1259,6 @@ struct gru ...@@ -1266,20 +1259,6 @@ struct gru
} }
}; };
struct gru_last_output
{
std::string name() const { return "gru_last_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
dims.erase(dims.begin());
return {inputs[0].type(), dims};
}
};
struct undefined struct undefined
{ {
std::string name() const { return "undefined"; } std::string name() const { return "undefined"; }
......
...@@ -734,14 +734,14 @@ struct onnx_parser ...@@ -734,14 +734,14 @@ struct onnx_parser
direction = attributes.at("direction").s(); direction = attributes.at("direction").s();
} }
op::rnn::rnn_direction_t dirct = op::rnn::forward; op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional") if(direction == "bidirectional")
{ {
dirct = op::rnn::bidirectional; dirct = op::rnn_direction::bidirectional;
} }
else if(direction == "reverse") else if(direction == "reverse")
{ {
dirct = op::rnn::reverse; dirct = op::rnn_direction::reverse;
} }
std::vector<std::string> vec_names{"tanh"}; std::vector<std::string> vec_names{"tanh"};
...@@ -763,7 +763,7 @@ struct onnx_parser ...@@ -763,7 +763,7 @@ struct onnx_parser
// one is for forward, and the other is for reverse. // one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both // if only one actv function is provided, we use it in both
// forward and reverse direction // forward and reverse direction
if(dirct == op::rnn::bidirectional) if(dirct == op::rnn_direction::bidirectional)
{ {
if(vec_names.size() == 1) if(vec_names.size() == 1)
{ {
...@@ -823,14 +823,14 @@ struct onnx_parser ...@@ -823,14 +823,14 @@ struct onnx_parser
direction = attributes.at("direction").s(); direction = attributes.at("direction").s();
} }
op::gru::gru_direction_t dirct = op::gru::forward; op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional") if(direction == "bidirectional")
{ {
dirct = op::gru::bidirectional; dirct = op::rnn_direction::bidirectional;
} }
else if(direction == "reverse") else if(direction == "reverse")
{ {
dirct = op::gru::reverse; dirct = op::rnn_direction::reverse;
} }
std::vector<std::string> vec_names = {"sigmoid", "tanh"}; std::vector<std::string> vec_names = {"sigmoid", "tanh"};
...@@ -844,7 +844,7 @@ struct onnx_parser ...@@ -844,7 +844,7 @@ struct onnx_parser
} }
// need 4 activation functions // need 4 activation functions
if(dirct == op::gru::bidirectional) if(dirct == op::rnn_direction::bidirectional)
{ {
// 4 activation functions are used in the bidirectional // 4 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we // scenario. No spec is provided in onnx::operator. we
...@@ -915,7 +915,7 @@ struct onnx_parser ...@@ -915,7 +915,7 @@ struct onnx_parser
std::move(args)); std::move(args));
// second output for last gru output // second output for last gru output
auto last_output = prog.add_instruction(op::gru_last_output{}, hidden_states); auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output}; return {hidden_states, last_output};
} }
......
...@@ -42,9 +42,9 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const ...@@ -42,9 +42,9 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
auto actv_funcs = vanilla_rnn_actv_funcs(ins); auto actv_funcs = vanilla_rnn_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator()); auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dicrt = rnn_op.direction; op::rnn_direction dicrt = rnn_op.direction;
instruction_ref last_output{}; instruction_ref last_output{};
if(dicrt == op::rnn::bidirectional) if(dicrt == op::rnn_direction::bidirectional)
{ {
// input weight matrix // input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]); auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
...@@ -119,7 +119,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const ...@@ -119,7 +119,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
} }
else else
{ {
bool is_forward = (dicrt == op::rnn::forward); bool is_forward = (dicrt == op::rnn_direction::forward);
// input weight matrix // input weight matrix
auto w = args[1]; auto w = args[1];
...@@ -275,7 +275,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) ...@@ -275,7 +275,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
// append undefined operators to make 6 arguments when parsing // append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have any num of arguments // an onnx file. Another case is user can have any num of arguments
// when writing their program. // when writing their program.
if(rnn_op.direction == op::rnn::bidirectional) if(rnn_op.direction == op::rnn_direction::bidirectional)
{ {
if(rnn_op.actv_funcs.empty()) if(rnn_op.actv_funcs.empty())
{ {
...@@ -323,9 +323,9 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -323,9 +323,9 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
std::vector<float> data(ih_shape.elements(), 0.0); std::vector<float> data(ih_shape.elements(), 0.0);
auto gru_op = any_cast<op::gru>(ins->get_operator()); auto gru_op = any_cast<op::gru>(ins->get_operator());
op::gru::gru_direction_t dicrt = gru_op.direction; op::rnn_direction dicrt = gru_op.direction;
instruction_ref last_output{}; instruction_ref last_output{};
if(dicrt == op::gru::bidirectional) if(dicrt == op::rnn_direction::bidirectional)
{ {
// w weight matrix // w weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]); auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
...@@ -395,7 +395,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -395,7 +395,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
} }
else else
{ {
bool is_forward = (dicrt == op::gru::forward); bool is_forward = (dicrt == op::rnn_direction::forward);
// weight matrix // weight matrix
auto w = args[1]; auto w = args[1];
auto r = args[2]; auto r = args[2];
...@@ -440,14 +440,14 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -440,14 +440,14 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
} }
} }
// replace the corresponding gru_last_output instruction // replace the corresponding rnn_last_output instruction
// with the last_output, if gru_last_output exists // with the last_output, if rnn_last_output exists
// while loop to handle case of multiple gru_last_output operators // while loop to handle case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin(); auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end()) while(last_output_it != ins->outputs().end())
{ {
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) { last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "gru_last_output"; return i->name() == "rnn_last_output";
}); });
if(last_output_it != ins->outputs().end()) if(last_output_it != ins->outputs().end())
...@@ -631,7 +631,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const ...@@ -631,7 +631,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
// we have 4 actv funcs, even though a user does not // we have 4 actv funcs, even though a user does not
// specifiy any actv func. If less than 4, use the // specifiy any actv func. If less than 4, use the
// algorithm in parse_gru to make 4 actv functions // algorithm in parse_gru to make 4 actv functions
if(gru_op.direction == op::gru::bidirectional) if(gru_op.direction == op::rnn_direction::bidirectional)
{ {
if(gru_op.actv_funcs.empty()) if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}}; return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
......
This diff is collapsed.
...@@ -1156,7 +1156,7 @@ struct test_rnn_forward ...@@ -1156,7 +1156,7 @@ struct test_rnn_forward
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1198,7 +1198,7 @@ struct test_rnn_forward10 ...@@ -1198,7 +1198,7 @@ struct test_rnn_forward10
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1239,7 +1239,7 @@ struct test_rnn_reverse ...@@ -1239,7 +1239,7 @@ struct test_rnn_reverse
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1279,7 +1279,7 @@ struct test_rnn_reverse2 ...@@ -1279,7 +1279,7 @@ struct test_rnn_reverse2
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1314,7 +1314,7 @@ struct test_rnn_3args ...@@ -1314,7 +1314,7 @@ struct test_rnn_3args
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1348,7 +1348,7 @@ struct test_rnn_4args ...@@ -1348,7 +1348,7 @@ struct test_rnn_4args
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1385,7 +1385,7 @@ struct test_rnn_5args ...@@ -1385,7 +1385,7 @@ struct test_rnn_5args
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1426,7 +1426,7 @@ struct test_rnn_bidirectional ...@@ -1426,7 +1426,7 @@ struct test_rnn_bidirectional
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1467,7 +1467,7 @@ struct test_rnn_bidirectional10 ...@@ -1467,7 +1467,7 @@ struct test_rnn_bidirectional10
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1505,7 +1505,7 @@ struct test_rnn_bi_3args ...@@ -1505,7 +1505,7 @@ struct test_rnn_bi_3args
auto output = auto output =
p.add_instruction(migraphx::op::rnn{hidden_size, p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1546,7 +1546,7 @@ struct test_gru_forward_last ...@@ -1546,7 +1546,7 @@ struct test_gru_forward_last
auto output = auto output =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1554,7 +1554,7 @@ struct test_gru_forward_last ...@@ -1554,7 +1554,7 @@ struct test_gru_forward_last
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p; return p;
} }
...@@ -1589,7 +1589,7 @@ struct test_gru_forward_hs ...@@ -1589,7 +1589,7 @@ struct test_gru_forward_hs
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1625,7 +1625,7 @@ struct test_gru_forward_3args_und ...@@ -1625,7 +1625,7 @@ struct test_gru_forward_3args_und
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1660,7 +1660,7 @@ struct test_gru_forward_3args ...@@ -1660,7 +1660,7 @@ struct test_gru_forward_3args
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1692,7 +1692,7 @@ struct test_gru_forward_seq1 ...@@ -1692,7 +1692,7 @@ struct test_gru_forward_seq1
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -1723,7 +1723,7 @@ struct test_gru_forward_default_actv ...@@ -1723,7 +1723,7 @@ struct test_gru_forward_default_actv
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction( p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::gru::forward, clip}, seq, w, r); migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, seq, w, r);
return p; return p;
} }
...@@ -1758,7 +1758,7 @@ struct test_gru_forward_default_actv1 ...@@ -1758,7 +1758,7 @@ struct test_gru_forward_default_actv1
p.add_instruction( p.add_instruction(
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::forward, clip}, hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip},
seq, seq,
w, w,
r, r,
...@@ -1800,7 +1800,7 @@ struct test_gru_reverse_last ...@@ -1800,7 +1800,7 @@ struct test_gru_reverse_last
auto output = auto output =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1808,7 +1808,7 @@ struct test_gru_reverse_last ...@@ -1808,7 +1808,7 @@ struct test_gru_reverse_last
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p; return p;
} }
...@@ -1836,7 +1836,7 @@ struct test_gru_reverse_3args ...@@ -1836,7 +1836,7 @@ struct test_gru_reverse_3args
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -1876,7 +1876,7 @@ struct test_gru_bidirct_last ...@@ -1876,7 +1876,7 @@ struct test_gru_bidirct_last
auto output = auto output =
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1884,7 +1884,7 @@ struct test_gru_bidirct_last ...@@ -1884,7 +1884,7 @@ struct test_gru_bidirct_last
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, output); p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p; return p;
} }
...@@ -1919,7 +1919,7 @@ struct test_gru_bidirct_hs ...@@ -1919,7 +1919,7 @@ struct test_gru_bidirct_hs
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1955,7 +1955,7 @@ struct test_gru_bidirct_3args_und ...@@ -1955,7 +1955,7 @@ struct test_gru_bidirct_3args_und
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -1990,7 +1990,7 @@ struct test_gru_bidirct_3args ...@@ -1990,7 +1990,7 @@ struct test_gru_bidirct_3args
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -2022,7 +2022,7 @@ struct test_gru_bidirct_seq1 ...@@ -2022,7 +2022,7 @@ struct test_gru_bidirct_seq1
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size, p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -2053,7 +2053,7 @@ struct test_gru_bidirct_default_actv ...@@ -2053,7 +2053,7 @@ struct test_gru_bidirct_default_actv
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction( p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::gru::bidirectional, clip}, seq, w, r); migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r);
return p; return p;
} }
...@@ -2088,7 +2088,7 @@ struct test_gru_bidirct_default_actv1 ...@@ -2088,7 +2088,7 @@ struct test_gru_bidirct_default_actv1
p.add_instruction( p.add_instruction(
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::bidirectional, clip}, hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
r, r,
......
...@@ -491,7 +491,7 @@ TEST_CASE(rnn_test) ...@@ -491,7 +491,7 @@ TEST_CASE(rnn_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -523,7 +523,7 @@ TEST_CASE(rnn_test) ...@@ -523,7 +523,7 @@ TEST_CASE(rnn_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -555,7 +555,7 @@ TEST_CASE(rnn_test) ...@@ -555,7 +555,7 @@ TEST_CASE(rnn_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -583,7 +583,7 @@ TEST_CASE(rnn_test) ...@@ -583,7 +583,7 @@ TEST_CASE(rnn_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -615,7 +615,7 @@ TEST_CASE(rnn_test) ...@@ -615,7 +615,7 @@ TEST_CASE(rnn_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::rnn{hs, p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -658,15 +658,15 @@ TEST_CASE(gru_test) ...@@ -658,15 +658,15 @@ TEST_CASE(gru_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip, 1},
seq, seq,
w, w,
r, r,
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -692,7 +692,7 @@ TEST_CASE(gru_test) ...@@ -692,7 +692,7 @@ TEST_CASE(gru_test)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -700,7 +700,7 @@ TEST_CASE(gru_test) ...@@ -700,7 +700,7 @@ TEST_CASE(gru_test)
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -728,7 +728,7 @@ TEST_CASE(gru_test) ...@@ -728,7 +728,7 @@ TEST_CASE(gru_test)
migraphx::op::sigmoid{}, migraphx::op::sigmoid{},
migraphx::op::relu{}, migraphx::op::relu{},
migraphx::op::tanh{}}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -736,7 +736,7 @@ TEST_CASE(gru_test) ...@@ -736,7 +736,7 @@ TEST_CASE(gru_test)
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -767,7 +767,7 @@ TEST_CASE(gru_test_args) ...@@ -767,7 +767,7 @@ TEST_CASE(gru_test_args)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::forward, migraphx::op::rnn_direction::forward,
clip}, clip},
seq, seq,
w, w,
...@@ -775,7 +775,7 @@ TEST_CASE(gru_test_args) ...@@ -775,7 +775,7 @@ TEST_CASE(gru_test_args)
und, und,
und, und,
und); und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -799,7 +799,7 @@ TEST_CASE(gru_test_args) ...@@ -799,7 +799,7 @@ TEST_CASE(gru_test_args)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::reverse, migraphx::op::rnn_direction::reverse,
clip}, clip},
seq, seq,
w, w,
...@@ -807,7 +807,7 @@ TEST_CASE(gru_test_args) ...@@ -807,7 +807,7 @@ TEST_CASE(gru_test_args)
bias, bias,
und, und,
und); und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -833,7 +833,7 @@ TEST_CASE(gru_test_args) ...@@ -833,7 +833,7 @@ TEST_CASE(gru_test_args)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -841,7 +841,7 @@ TEST_CASE(gru_test_args) ...@@ -841,7 +841,7 @@ TEST_CASE(gru_test_args)
bias, bias,
seq_len, seq_len,
und); und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -874,14 +874,14 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -874,14 +874,14 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::bidirectional, clip}, p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
r, r,
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -905,14 +905,14 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -905,14 +905,14 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip}, migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
r, r,
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -938,7 +938,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -938,7 +938,7 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = auto out_hs =
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -946,7 +946,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -946,7 +946,7 @@ TEST_CASE(gru_test_actv_funcs)
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -972,7 +972,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -972,7 +972,7 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional, migraphx::op::rnn_direction::bidirectional,
clip}, clip},
seq, seq,
w, w,
...@@ -980,7 +980,7 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -980,7 +980,7 @@ TEST_CASE(gru_test_actv_funcs)
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -1003,14 +1003,14 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -1003,14 +1003,14 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::forward, clip}, auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
seq, seq,
w, w,
r, r,
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog); EXPECT(p == prog);
...@@ -1034,14 +1034,14 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -1034,14 +1034,14 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::gru::reverse, clip}, migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq, seq,
w, w,
r, r,
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog); EXPECT(p == prog);
......
...@@ -285,7 +285,7 @@ TEST_CASE(rnn) ...@@ -285,7 +285,7 @@ TEST_CASE(rnn)
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{ migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -310,7 +310,7 @@ TEST_CASE(rnn) ...@@ -310,7 +310,7 @@ TEST_CASE(rnn)
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{ migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::reverse, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -336,7 +336,7 @@ TEST_CASE(rnn) ...@@ -336,7 +336,7 @@ TEST_CASE(rnn)
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{ migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -360,7 +360,7 @@ TEST_CASE(rnn) ...@@ -360,7 +360,7 @@ TEST_CASE(rnn)
throws_shape( throws_shape(
migraphx::op::rnn{ migraphx::op::rnn{
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip}, hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -384,7 +384,7 @@ TEST_CASE(rnn) ...@@ -384,7 +384,7 @@ TEST_CASE(rnn)
throws_shape( throws_shape(
migraphx::op::rnn{ migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -408,7 +408,7 @@ TEST_CASE(rnn) ...@@ -408,7 +408,7 @@ TEST_CASE(rnn)
throws_shape( throws_shape(
migraphx::op::rnn{ migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -438,7 +438,7 @@ TEST_CASE(gru) ...@@ -438,7 +438,7 @@ TEST_CASE(gru)
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -465,7 +465,7 @@ TEST_CASE(gru) ...@@ -465,7 +465,7 @@ TEST_CASE(gru)
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::reverse, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -493,7 +493,7 @@ TEST_CASE(gru) ...@@ -493,7 +493,7 @@ TEST_CASE(gru)
migraphx::shape{migraphx::shape::float_type, migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -519,7 +519,7 @@ TEST_CASE(gru) ...@@ -519,7 +519,7 @@ TEST_CASE(gru)
throws_shape( throws_shape(
migraphx::op::gru{ migraphx::op::gru{
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip}, hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -545,7 +545,7 @@ TEST_CASE(gru) ...@@ -545,7 +545,7 @@ TEST_CASE(gru)
throws_shape( throws_shape(
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -571,7 +571,7 @@ TEST_CASE(gru) ...@@ -571,7 +571,7 @@ TEST_CASE(gru)
throws_shape( throws_shape(
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment