Commit 1877b194 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix comments about enum rnn_direction, rnn_last_output, and added more tests...

fix comments about enum rnn_direction, rnn_last_output, and added more tests to achieve better code coverage
parent 35bc9bc7
......@@ -1164,19 +1164,19 @@ struct outline
argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
};
struct rnn
// indicate rnn computation direction
enum class rnn_direction
{
enum rnn_direction_t
{
forward,
reverse,
bidirectional,
};
};
struct rnn
{
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{tanh{}, tanh{}};
rnn_direction_t direction = forward;
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
std::string name() const { return "rnn"; }
......@@ -1190,7 +1190,7 @@ struct rnn
}
std::size_t num_directions = 1;
if(direction == bidirectional)
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
......@@ -1224,16 +1224,9 @@ struct rnn_last_output
struct gru
{
enum gru_direction_t
{
forward,
reverse,
bidirectional,
};
std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
gru_direction_t direction = forward;
rnn_direction direction = rnn_direction::forward;
float clip = 0.0f;
int linear_before_reset = 0;
......@@ -1248,7 +1241,7 @@ struct gru
}
std::size_t num_directions = 1;
if(direction == bidirectional)
if(direction == rnn_direction::bidirectional)
{
num_directions = 2;
}
......@@ -1266,20 +1259,6 @@ struct gru
}
};
struct gru_last_output
{
std::string name() const { return "gru_last_output"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto dims = inputs[0].lens();
// remove the first dimension, remaing are output shape
dims.erase(dims.begin());
return {inputs[0].type(), dims};
}
};
struct undefined
{
std::string name() const { return "undefined"; }
......
......@@ -734,14 +734,14 @@ struct onnx_parser
direction = attributes.at("direction").s();
}
op::rnn::rnn_direction_t dirct = op::rnn::forward;
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::rnn::bidirectional;
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::rnn::reverse;
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names{"tanh"};
......@@ -763,7 +763,7 @@ struct onnx_parser
// one is for forward, and the other is for reverse.
// if only one actv function is provided, we use it in both
// forward and reverse direction
if(dirct == op::rnn::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
if(vec_names.size() == 1)
{
......@@ -823,14 +823,14 @@ struct onnx_parser
direction = attributes.at("direction").s();
}
op::gru::gru_direction_t dirct = op::gru::forward;
op::rnn_direction dirct = op::rnn_direction::forward;
if(direction == "bidirectional")
{
dirct = op::gru::bidirectional;
dirct = op::rnn_direction::bidirectional;
}
else if(direction == "reverse")
{
dirct = op::gru::reverse;
dirct = op::rnn_direction::reverse;
}
std::vector<std::string> vec_names = {"sigmoid", "tanh"};
......@@ -844,7 +844,7 @@ struct onnx_parser
}
// need 4 activation functions
if(dirct == op::gru::bidirectional)
if(dirct == op::rnn_direction::bidirectional)
{
// 4 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
......@@ -915,7 +915,7 @@ struct onnx_parser
std::move(args));
// second output for last gru output
auto last_output = prog.add_instruction(op::gru_last_output{}, hidden_states);
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
return {hidden_states, last_output};
}
......
......@@ -42,9 +42,9 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
auto actv_funcs = vanilla_rnn_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn::rnn_direction_t dicrt = rnn_op.direction;
op::rnn_direction dicrt = rnn_op.direction;
instruction_ref last_output{};
if(dicrt == op::rnn::bidirectional)
if(dicrt == op::rnn_direction::bidirectional)
{
// input weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
......@@ -119,7 +119,7 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
}
else
{
bool is_forward = (dicrt == op::rnn::forward);
bool is_forward = (dicrt == op::rnn_direction::forward);
// input weight matrix
auto w = args[1];
......@@ -275,7 +275,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
// append undefined operators to make 6 arguments when parsing
// an onnx file. Another case is user can have any num of arguments
// when writing their program.
if(rnn_op.direction == op::rnn::bidirectional)
if(rnn_op.direction == op::rnn_direction::bidirectional)
{
if(rnn_op.actv_funcs.empty())
{
......@@ -323,9 +323,9 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
std::vector<float> data(ih_shape.elements(), 0.0);
auto gru_op = any_cast<op::gru>(ins->get_operator());
op::gru::gru_direction_t dicrt = gru_op.direction;
op::rnn_direction dicrt = gru_op.direction;
instruction_ref last_output{};
if(dicrt == op::gru::bidirectional)
if(dicrt == op::rnn_direction::bidirectional)
{
// w weight matrix
auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
......@@ -395,7 +395,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
}
else
{
bool is_forward = (dicrt == op::gru::forward);
bool is_forward = (dicrt == op::rnn_direction::forward);
// weight matrix
auto w = args[1];
auto r = args[2];
......@@ -440,14 +440,14 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
}
}
// replace the corresponding gru_last_output instruction
// with the last_output, if gru_last_output exists
// while loop to handle case of multiple gru_last_output operators
// replace the corresponding rnn_last_output instruction
// with the last_output, if rnn_last_output exists
// while loop to handle case of multiple rnn_last_output operators
auto last_output_it = ins->outputs().begin();
while(last_output_it != ins->outputs().end())
{
last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
return i->name() == "gru_last_output";
return i->name() == "rnn_last_output";
});
if(last_output_it != ins->outputs().end())
......@@ -631,7 +631,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
// we have 4 actv funcs, even though a user does not
// specifiy any actv func. If less than 4, use the
// algorithm in parse_gru to make 4 actv functions
if(gru_op.direction == op::gru::bidirectional)
if(gru_op.direction == op::rnn_direction::bidirectional)
{
if(gru_op.actv_funcs.empty())
return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
......
This diff is collapsed.
......@@ -1156,7 +1156,7 @@ struct test_rnn_forward
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1198,7 +1198,7 @@ struct test_rnn_forward10
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1239,7 +1239,7 @@ struct test_rnn_reverse
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1279,7 +1279,7 @@ struct test_rnn_reverse2
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1314,7 +1314,7 @@ struct test_rnn_3args
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1348,7 +1348,7 @@ struct test_rnn_4args
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1385,7 +1385,7 @@ struct test_rnn_5args
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1426,7 +1426,7 @@ struct test_rnn_bidirectional
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1467,7 +1467,7 @@ struct test_rnn_bidirectional10
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1505,7 +1505,7 @@ struct test_rnn_bi_3args
auto output =
p.add_instruction(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1546,7 +1546,7 @@ struct test_gru_forward_last
auto output =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1554,7 +1554,7 @@ struct test_gru_forward_last
bias,
und,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
......@@ -1589,7 +1589,7 @@ struct test_gru_forward_hs
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1625,7 +1625,7 @@ struct test_gru_forward_3args_und
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1660,7 +1660,7 @@ struct test_gru_forward_3args
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1692,7 +1692,7 @@ struct test_gru_forward_seq1
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -1723,7 +1723,7 @@ struct test_gru_forward_default_actv
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::gru::forward, clip}, seq, w, r);
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, seq, w, r);
return p;
}
......@@ -1758,7 +1758,7 @@ struct test_gru_forward_default_actv1
p.add_instruction(
migraphx::op::gru{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::forward, clip},
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
......@@ -1800,7 +1800,7 @@ struct test_gru_reverse_last
auto output =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1808,7 +1808,7 @@ struct test_gru_reverse_last
bias,
und,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
......@@ -1836,7 +1836,7 @@ struct test_gru_reverse_3args
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -1876,7 +1876,7 @@ struct test_gru_bidirct_last
auto output =
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1884,7 +1884,7 @@ struct test_gru_bidirct_last
bias,
und,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, output);
p.add_instruction(migraphx::op::rnn_last_output{}, output);
return p;
}
......@@ -1919,7 +1919,7 @@ struct test_gru_bidirct_hs
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1955,7 +1955,7 @@ struct test_gru_bidirct_3args_und
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -1990,7 +1990,7 @@ struct test_gru_bidirct_3args
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -2022,7 +2022,7 @@ struct test_gru_bidirct_seq1
auto r = p.add_parameter("r", r_shape);
p.add_instruction(migraphx::op::gru{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -2053,7 +2053,7 @@ struct test_gru_bidirct_default_actv
auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape);
p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::gru::bidirectional, clip}, seq, w, r);
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r);
return p;
}
......@@ -2088,7 +2088,7 @@ struct test_gru_bidirct_default_actv1
p.add_instruction(
migraphx::op::gru{
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::gru::bidirectional, clip},
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
......
......@@ -491,7 +491,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -523,7 +523,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -555,7 +555,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -583,7 +583,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -615,7 +615,7 @@ TEST_CASE(rnn_test)
auto out_hs =
p.add_instruction(migraphx::op::rnn{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -658,15 +658,15 @@ TEST_CASE(gru_test)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::forward,
clip},
migraphx::op::rnn_direction::forward,
clip, 1},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward.onnx");
EXPECT(p == prog);
......@@ -692,7 +692,7 @@ TEST_CASE(gru_test)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -700,7 +700,7 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse.onnx");
EXPECT(p == prog);
......@@ -728,7 +728,7 @@ TEST_CASE(gru_test)
migraphx::op::sigmoid{},
migraphx::op::relu{},
migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -736,7 +736,7 @@ TEST_CASE(gru_test)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
EXPECT(p == prog);
......@@ -767,7 +767,7 @@ TEST_CASE(gru_test_args)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::forward,
migraphx::op::rnn_direction::forward,
clip},
seq,
w,
......@@ -775,7 +775,7 @@ TEST_CASE(gru_test_args)
und,
und,
und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_3arg.onnx");
EXPECT(p == prog);
......@@ -799,7 +799,7 @@ TEST_CASE(gru_test_args)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::reverse,
migraphx::op::rnn_direction::reverse,
clip},
seq,
w,
......@@ -807,7 +807,7 @@ TEST_CASE(gru_test_args)
bias,
und,
und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_4arg.onnx");
EXPECT(p == prog);
......@@ -833,7 +833,7 @@ TEST_CASE(gru_test_args)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -841,7 +841,7 @@ TEST_CASE(gru_test_args)
bias,
seq_len,
und);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_5arg.onnx");
EXPECT(p == prog);
......@@ -874,14 +874,14 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::bidirectional, clip},
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
EXPECT(p == prog);
......@@ -905,14 +905,14 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip},
migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_1.onnx");
EXPECT(p == prog);
......@@ -938,7 +938,7 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -946,7 +946,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_2.onnx");
EXPECT(p == prog);
......@@ -972,7 +972,7 @@ TEST_CASE(gru_test_actv_funcs)
auto out_hs = p.add_instruction(
migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}},
migraphx::op::gru::bidirectional,
migraphx::op::rnn_direction::bidirectional,
clip},
seq,
w,
......@@ -980,7 +980,7 @@ TEST_CASE(gru_test_actv_funcs)
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_3.onnx");
EXPECT(p == prog);
......@@ -1003,14 +1003,14 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::gru::forward, clip},
auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
EXPECT(p == prog);
......@@ -1034,14 +1034,14 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::gru::reverse, clip},
migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq,
w,
r,
bias,
seq_len,
ih);
p.add_instruction(migraphx::op::gru_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_reverse_1.onnx");
EXPECT(p == prog);
......
......@@ -285,7 +285,7 @@ TEST_CASE(rnn)
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......@@ -310,7 +310,7 @@ TEST_CASE(rnn)
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::reverse, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape,
w_shape,
r_shape,
......@@ -336,7 +336,7 @@ TEST_CASE(rnn)
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
in_shape,
w_shape,
r_shape,
......@@ -360,7 +360,7 @@ TEST_CASE(rnn)
throws_shape(
migraphx::op::rnn{
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip},
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......@@ -384,7 +384,7 @@ TEST_CASE(rnn)
throws_shape(
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
in_shape,
w_shape,
r_shape,
......@@ -408,7 +408,7 @@ TEST_CASE(rnn)
throws_shape(
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......@@ -438,7 +438,7 @@ TEST_CASE(gru)
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......@@ -465,7 +465,7 @@ TEST_CASE(gru)
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::reverse, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape,
w_shape,
r_shape,
......@@ -493,7 +493,7 @@ TEST_CASE(gru)
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
in_shape,
w_shape,
r_shape,
......@@ -519,7 +519,7 @@ TEST_CASE(gru)
throws_shape(
migraphx::op::gru{
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip},
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......@@ -545,7 +545,7 @@ TEST_CASE(gru)
throws_shape(
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::bidirectional, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
in_shape,
w_shape,
r_shape,
......@@ -571,7 +571,7 @@ TEST_CASE(gru)
throws_shape(
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::gru::forward, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment