Unverified Commit 6972ad26 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #170 from ROCmSoftwarePlatform/gru_operator

Gru operator
parents fb8fda8f bce629f1
......@@ -276,7 +276,58 @@ TEST_CASE(rnn)
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
......@@ -285,8 +336,10 @@ TEST_CASE(rnn)
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip},
migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
......@@ -302,17 +355,41 @@ TEST_CASE(rnn)
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::reverse, clip},
throws_shape(migraphx::op::rnn{hidden_size + 1,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(migraphx::op::rnn{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
......@@ -328,25 +405,25 @@ TEST_CASE(rnn)
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
throws_shape(
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
}
TEST_CASE(gru)
{
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
......@@ -355,16 +432,19 @@ TEST_CASE(rnn)
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::op::rnn{
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip},
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......@@ -380,16 +460,19 @@ TEST_CASE(rnn)
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::bidirectional, clip},
expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
in_shape,
w_shape,
r_shape,
......@@ -405,16 +488,100 @@ TEST_CASE(rnn)
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::program p;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(migraphx::op::gru{hidden_size + 1,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::forward,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape(migraphx::op::gru{hidden_size,
{migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape,
w_shape,
r_shape,
b_shape,
ih_shape);
}
{
std::size_t batch_size = 2;
std::size_t seq_len = 2;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type,
{num_dirct, 3 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape(
migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn::forward, clip},
migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
in_shape,
w_shape,
r_shape,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment