Commit ce7b4b17 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 1877b194
...@@ -111,8 +111,8 @@ TEST_CASE(rnn_forward) ...@@ -111,8 +111,8 @@ TEST_CASE(rnn_forward)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq, seq,
w, w,
r, r,
...@@ -148,8 +148,8 @@ TEST_CASE(rnn_forward) ...@@ -148,8 +148,8 @@ TEST_CASE(rnn_forward)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq, seq,
w, w,
r, r,
...@@ -183,7 +183,10 @@ TEST_CASE(rnn_forward) ...@@ -183,7 +183,10 @@ TEST_CASE(rnn_forward)
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, seq, w, r); migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
...@@ -294,7 +297,8 @@ TEST_CASE(rnn_reverse) ...@@ -294,7 +297,8 @@ TEST_CASE(rnn_reverse)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip}, p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip},
seq, seq,
w, w,
r, r,
...@@ -335,8 +339,8 @@ TEST_CASE(rnn_reverse) ...@@ -335,8 +339,8 @@ TEST_CASE(rnn_reverse)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip}, migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip},
seq, seq,
w, w,
r, r,
...@@ -451,9 +455,11 @@ TEST_CASE(rnn_bidirectional) ...@@ -451,9 +455,11 @@ TEST_CASE(rnn_bidirectional)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::rnn{ p.add_instruction(migraphx::op::rnn{hidden_size,
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, {migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
seq, seq,
w, w,
r, r,
...@@ -589,11 +595,22 @@ TEST_CASE(rnn_bidirectional) ...@@ -589,11 +595,22 @@ TEST_CASE(rnn_bidirectional)
std::vector<float> hs_data; std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{0.37780784,
0.37780784, 0.61055139, 0.55168478, -0.5888475, 0.61055139,
-0.37144644, 0.31708236, 0.13104209, -0.18736027, 0.55168478,
-0.16915828, 0.1938169, 0.20667936, 0.58609703, -0.5888475,
-0.0070999, 0.46251031, -0.20639211, 0.37488942}; -0.37144644,
0.31708236,
0.13104209,
-0.18736027,
-0.16915828,
0.1938169,
0.20667936,
0.58609703,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
...@@ -1074,9 +1091,11 @@ TEST_CASE(gru_forward_actv_funcs) ...@@ -1074,9 +1091,11 @@ TEST_CASE(gru_forward_actv_funcs)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction( p.add_instruction(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::sigmoid{}},
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip, 1}, migraphx::op::rnn_direction::forward,
clip,
1},
seq, seq,
w, w,
r, r,
...@@ -1365,7 +1384,8 @@ TEST_CASE(gru_reverse) ...@@ -1365,7 +1384,8 @@ TEST_CASE(gru_reverse)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 1}, p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 1},
seq, seq,
w, w,
r, r,
...@@ -1917,9 +1937,11 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -1917,9 +1937,11 @@ TEST_CASE(gru_bidirectional_actv_funcs)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction( p.add_instruction(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::sigmoid{}},
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq, seq,
w, w,
r, r,
...@@ -1952,9 +1974,11 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -1952,9 +1974,11 @@ TEST_CASE(gru_bidirectional_actv_funcs)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction( p.add_instruction(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 1}, migraphx::op::rnn_direction::bidirectional,
clip,
1},
seq, seq,
w, w,
r, r,
......
...@@ -1723,7 +1723,10 @@ struct test_gru_forward_default_actv ...@@ -1723,7 +1723,10 @@ struct test_gru_forward_default_actv
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction( p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, seq, w, r); migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
return p; return p;
} }
...@@ -2053,7 +2056,10 @@ struct test_gru_bidirct_default_actv ...@@ -2053,7 +2056,10 @@ struct test_gru_bidirct_default_actv
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction( p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r); migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r);
return p; return p;
} }
...@@ -2086,9 +2092,10 @@ struct test_gru_bidirct_default_actv1 ...@@ -2086,9 +2092,10 @@ struct test_gru_bidirct_default_actv1
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction( p.add_instruction(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::sigmoid{}},
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
clip},
seq, seq,
w, w,
r, r,
......
...@@ -659,7 +659,8 @@ TEST_CASE(gru_test) ...@@ -659,7 +659,8 @@ TEST_CASE(gru_test)
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip, 1}, clip,
1},
seq, seq,
w, w,
r, r,
...@@ -723,7 +724,8 @@ TEST_CASE(gru_test) ...@@ -723,7 +724,8 @@ TEST_CASE(gru_test)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::gru{hs, auto out_hs =
p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, {migraphx::op::tanh{},
migraphx::op::sigmoid{}, migraphx::op::sigmoid{},
migraphx::op::relu{}, migraphx::op::relu{},
...@@ -873,8 +875,8 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -873,8 +875,8 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
r, r,
...@@ -905,7 +907,8 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -905,7 +907,8 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
r, r,
...@@ -1003,7 +1006,8 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -1003,7 +1006,8 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip}, auto out_hs =
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
seq, seq,
w, w,
r, r,
...@@ -1034,7 +1038,8 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -1034,7 +1038,8 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip}, migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq, seq,
w, w,
r, r,
......
...@@ -282,7 +282,8 @@ TEST_CASE(rnn) ...@@ -282,7 +282,8 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{ migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
...@@ -307,7 +308,8 @@ TEST_CASE(rnn) ...@@ -307,7 +308,8 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{ migraphx::op::rnn{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
...@@ -332,11 +334,12 @@ TEST_CASE(rnn) ...@@ -332,11 +334,12 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type,
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::rnn{ migraphx::op::rnn{hidden_size,
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, {migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -358,9 +361,10 @@ TEST_CASE(rnn) ...@@ -358,9 +361,10 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape( throws_shape(migraphx::op::rnn{hidden_size + 1,
migraphx::op::rnn{ {migraphx::op::tanh{}},
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::rnn_direction::forward,
clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -382,9 +386,10 @@ TEST_CASE(rnn) ...@@ -382,9 +386,10 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape( throws_shape(migraphx::op::rnn{hidden_size,
migraphx::op::rnn{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -435,7 +440,8 @@ TEST_CASE(gru) ...@@ -435,7 +440,8 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
...@@ -462,7 +468,8 @@ TEST_CASE(gru) ...@@ -462,7 +468,8 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{ migraphx::op::gru{
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip}, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
...@@ -489,11 +496,12 @@ TEST_CASE(gru) ...@@ -489,11 +496,12 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type,
migraphx::shape{migraphx::shape::float_type,
{seq_len, num_dirct, batch_size, hidden_size}}, {seq_len, num_dirct, batch_size, hidden_size}},
migraphx::op::gru{ migraphx::op::gru{hidden_size,
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, {migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -517,9 +525,10 @@ TEST_CASE(gru) ...@@ -517,9 +525,10 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape( throws_shape(migraphx::op::gru{hidden_size + 1,
migraphx::op::gru{ {migraphx::op::tanh{}},
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::rnn_direction::forward,
clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
...@@ -543,9 +552,10 @@ TEST_CASE(gru) ...@@ -543,9 +552,10 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape( throws_shape(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
clip},
in_shape, in_shape,
w_shape, w_shape,
r_shape, r_shape,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment