Commit ce7b4b17 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 1877b194
...@@ -1177,7 +1177,7 @@ struct rnn ...@@ -1177,7 +1177,7 @@ struct rnn
std::size_t hidden_size = 1; std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{tanh{}, tanh{}}; std::vector<operation> actv_funcs{tanh{}, tanh{}};
rnn_direction direction = rnn_direction::forward; rnn_direction direction = rnn_direction::forward;
float clip = 0.0f; float clip = 0.0f;
std::string name() const { return "rnn"; } std::string name() const { return "rnn"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
...@@ -1227,8 +1227,8 @@ struct gru ...@@ -1227,8 +1227,8 @@ struct gru
std::size_t hidden_size = 1; std::size_t hidden_size = 1;
std::vector<operation> actv_funcs{sigmoid{}, tanh{}}; std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
rnn_direction direction = rnn_direction::forward; rnn_direction direction = rnn_direction::forward;
float clip = 0.0f; float clip = 0.0f;
int linear_before_reset = 0; int linear_before_reset = 0;
std::string name() const { return "gru"; } std::string name() const { return "gru"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
......
...@@ -40,8 +40,8 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const ...@@ -40,8 +40,8 @@ void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}}; migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0); std::vector<float> data(ih_shape.elements(), 0);
auto actv_funcs = vanilla_rnn_actv_funcs(ins); auto actv_funcs = vanilla_rnn_actv_funcs(ins);
auto rnn_op = any_cast<op::rnn>(ins->get_operator()); auto rnn_op = any_cast<op::rnn>(ins->get_operator());
op::rnn_direction dicrt = rnn_op.direction; op::rnn_direction dicrt = rnn_op.direction;
instruction_ref last_output{}; instruction_ref last_output{};
if(dicrt == op::rnn_direction::bidirectional) if(dicrt == op::rnn_direction::bidirectional)
...@@ -322,7 +322,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const ...@@ -322,7 +322,7 @@ void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
migraphx::shape ih_shape{type, {1, batch_size, hidden_size}}; migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
std::vector<float> data(ih_shape.elements(), 0.0); std::vector<float> data(ih_shape.elements(), 0.0);
auto gru_op = any_cast<op::gru>(ins->get_operator()); auto gru_op = any_cast<op::gru>(ins->get_operator());
op::rnn_direction dicrt = gru_op.direction; op::rnn_direction dicrt = gru_op.direction;
instruction_ref last_output{}; instruction_ref last_output{};
if(dicrt == op::rnn_direction::bidirectional) if(dicrt == op::rnn_direction::bidirectional)
......
...@@ -111,14 +111,14 @@ TEST_CASE(rnn_forward) ...@@ -111,14 +111,14 @@ TEST_CASE(rnn_forward)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq, seq,
w, w,
r, r,
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
...@@ -148,14 +148,14 @@ TEST_CASE(rnn_forward) ...@@ -148,14 +148,14 @@ TEST_CASE(rnn_forward)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq, seq,
w, w,
r, r,
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
...@@ -183,7 +183,10 @@ TEST_CASE(rnn_forward) ...@@ -183,7 +183,10 @@ TEST_CASE(rnn_forward)
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, seq, w, r); migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
...@@ -294,13 +297,14 @@ TEST_CASE(rnn_reverse) ...@@ -294,13 +297,14 @@ TEST_CASE(rnn_reverse)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip}, p.add_instruction(
seq, migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip},
w, seq,
r, w,
bias, r,
und, bias,
ih); und,
ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
...@@ -335,14 +339,14 @@ TEST_CASE(rnn_reverse) ...@@ -335,14 +339,14 @@ TEST_CASE(rnn_reverse)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip}, migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip},
seq, seq,
w, w,
r, r,
bias, bias,
und, und,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
...@@ -451,15 +455,17 @@ TEST_CASE(rnn_bidirectional) ...@@ -451,15 +455,17 @@ TEST_CASE(rnn_bidirectional)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto out_hs = p.add_instruction( auto out_hs =
migraphx::op::rnn{ p.add_instruction(migraphx::op::rnn{hidden_size,
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, {migraphx::op::tanh{}},
seq, migraphx::op::rnn_direction::bidirectional,
w, clip},
r, seq,
bias, w,
und, r,
ih); bias,
und,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
...@@ -589,11 +595,22 @@ TEST_CASE(rnn_bidirectional) ...@@ -589,11 +595,22 @@ TEST_CASE(rnn_bidirectional)
std::vector<float> hs_data; std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{0.37780784,
0.37780784, 0.61055139, 0.55168478, -0.5888475, 0.61055139,
-0.37144644, 0.31708236, 0.13104209, -0.18736027, 0.55168478,
-0.16915828, 0.1938169, 0.20667936, 0.58609703, -0.5888475,
-0.0070999, 0.46251031, -0.20639211, 0.37488942}; -0.37144644,
0.31708236,
0.13104209,
-0.18736027,
-0.16915828,
0.1938169,
0.20667936,
0.58609703,
-0.0070999,
0.46251031,
-0.20639211,
0.37488942};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
...@@ -1074,15 +1091,17 @@ TEST_CASE(gru_forward_actv_funcs) ...@@ -1074,15 +1091,17 @@ TEST_CASE(gru_forward_actv_funcs)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction( p.add_instruction(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::sigmoid{}},
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip, 1}, migraphx::op::rnn_direction::forward,
seq, clip,
w, 1},
r, seq,
bias, w,
und, r,
ih); bias,
und,
ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
...@@ -1365,13 +1384,14 @@ TEST_CASE(gru_reverse) ...@@ -1365,13 +1384,14 @@ TEST_CASE(gru_reverse)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction(migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 1}, p.add_instruction(
seq, migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 1},
w, seq,
r, w,
bias, r,
und, bias,
ih); und,
ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
...@@ -1917,15 +1937,17 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -1917,15 +1937,17 @@ TEST_CASE(gru_bidirectional_actv_funcs)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction( p.add_instruction(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::sigmoid{}},
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, migraphx::op::rnn_direction::bidirectional,
seq, clip,
w, 0},
r, seq,
bias, w,
und, r,
ih); bias,
und,
ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
...@@ -1952,15 +1974,17 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -1952,15 +1974,17 @@ TEST_CASE(gru_bidirectional_actv_funcs)
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
p.add_instruction( p.add_instruction(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 1}, migraphx::op::rnn_direction::bidirectional,
seq, clip,
w, 1},
r, seq,
bias, w,
und, r,
ih); bias,
und,
ih);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> hs_data; std::vector<float> hs_data;
......
...@@ -1723,7 +1723,10 @@ struct test_gru_forward_default_actv ...@@ -1723,7 +1723,10 @@ struct test_gru_forward_default_actv
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction( p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, seq, w, r); migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip},
seq,
w,
r);
return p; return p;
} }
...@@ -2053,7 +2056,10 @@ struct test_gru_bidirct_default_actv ...@@ -2053,7 +2056,10 @@ struct test_gru_bidirct_default_actv
auto w = p.add_parameter("w", w_shape); auto w = p.add_parameter("w", w_shape);
auto r = p.add_parameter("r", r_shape); auto r = p.add_parameter("r", r_shape);
p.add_instruction( p.add_instruction(
migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r); migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq,
w,
r);
return p; return p;
} }
...@@ -2086,15 +2092,16 @@ struct test_gru_bidirct_default_actv1 ...@@ -2086,15 +2092,16 @@ struct test_gru_bidirct_default_actv1
auto ih = p.add_parameter("ih", ih_shape); auto ih = p.add_parameter("ih", ih_shape);
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction( p.add_instruction(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::sigmoid{}},
hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
seq, clip},
w, seq,
r, w,
bias, r,
und, bias,
ih); und,
ih);
return p; return p;
} }
......
...@@ -659,7 +659,8 @@ TEST_CASE(gru_test) ...@@ -659,7 +659,8 @@ TEST_CASE(gru_test)
p.add_instruction(migraphx::op::gru{hs, p.add_instruction(migraphx::op::gru{hs,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::forward, migraphx::op::rnn_direction::forward,
clip, 1}, clip,
1},
seq, seq,
w, w,
r, r,
...@@ -723,19 +724,20 @@ TEST_CASE(gru_test) ...@@ -723,19 +724,20 @@ TEST_CASE(gru_test)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::gru{hs, auto out_hs =
{migraphx::op::tanh{}, p.add_instruction(migraphx::op::gru{hs,
migraphx::op::sigmoid{}, {migraphx::op::tanh{},
migraphx::op::relu{}, migraphx::op::sigmoid{},
migraphx::op::tanh{}}, migraphx::op::relu{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::tanh{}},
clip}, migraphx::op::rnn_direction::bidirectional,
seq, clip},
w, seq,
r, w,
bias, r,
seq_len, bias,
ih); seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi.onnx");
...@@ -873,14 +875,14 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -873,14 +875,14 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = auto out_hs = p.add_instruction(
p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
r, r,
bias, bias,
seq_len, seq_len,
ih); ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_bi_0.onnx");
...@@ -905,7 +907,8 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -905,7 +907,8 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::gru{
hs, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip},
seq, seq,
w, w,
r, r,
...@@ -1003,13 +1006,14 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -1003,13 +1006,14 @@ TEST_CASE(gru_test_actv_funcs)
p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}); p.add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip}, auto out_hs =
seq, p.add_instruction(migraphx::op::gru{hs, {}, migraphx::op::rnn_direction::forward, clip},
w, seq,
r, w,
bias, r,
seq_len, bias,
ih); seq_len,
ih);
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx"); auto prog = migraphx::parse_onnx("onnx_gru_forward_0.onnx");
...@@ -1034,7 +1038,8 @@ TEST_CASE(gru_test_actv_funcs) ...@@ -1034,7 +1038,8 @@ TEST_CASE(gru_test_actv_funcs)
auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}); auto ih = p.add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
auto out_hs = p.add_instruction( auto out_hs = p.add_instruction(
migraphx::op::gru{hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip}, migraphx::op::gru{
hs, {migraphx::op::relu{}}, migraphx::op::rnn_direction::reverse, clip},
seq, seq,
w, w,
r, r,
......
...@@ -282,15 +282,16 @@ TEST_CASE(rnn) ...@@ -282,15 +282,16 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::shape{migraphx::shape::float_type,
migraphx::op::rnn{ {seq_len, num_dirct, batch_size, hidden_size}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::rnn{
in_shape, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -307,15 +308,16 @@ TEST_CASE(rnn) ...@@ -307,15 +308,16 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::shape{migraphx::shape::float_type,
migraphx::op::rnn{ {seq_len, num_dirct, batch_size, hidden_size}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip}, migraphx::op::rnn{
in_shape, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -332,16 +334,17 @@ TEST_CASE(rnn) ...@@ -332,16 +334,17 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type,
migraphx::shape{migraphx::shape::float_type, {seq_len, num_dirct, batch_size, hidden_size}},
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::op::rnn{hidden_size,
migraphx::op::rnn{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -358,14 +361,15 @@ TEST_CASE(rnn) ...@@ -358,14 +361,15 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape( throws_shape(migraphx::op::rnn{hidden_size + 1,
migraphx::op::rnn{ {migraphx::op::tanh{}},
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::rnn_direction::forward,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -382,14 +386,15 @@ TEST_CASE(rnn) ...@@ -382,14 +386,15 @@ TEST_CASE(rnn)
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}};
throws_shape( throws_shape(migraphx::op::rnn{hidden_size,
migraphx::op::rnn{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -435,15 +440,16 @@ TEST_CASE(gru) ...@@ -435,15 +440,16 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::shape{migraphx::shape::float_type,
migraphx::op::gru{ {seq_len, num_dirct, batch_size, hidden_size}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::gru{
in_shape, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -462,15 +468,16 @@ TEST_CASE(gru) ...@@ -462,15 +468,16 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape(migraphx::shape{migraphx::shape::float_type, expect_shape(
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::shape{migraphx::shape::float_type,
migraphx::op::gru{ {seq_len, num_dirct, batch_size, hidden_size}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip}, migraphx::op::gru{
in_shape, hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -489,16 +496,17 @@ TEST_CASE(gru) ...@@ -489,16 +496,17 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
expect_shape( expect_shape(migraphx::shape{migraphx::shape::float_type,
migraphx::shape{migraphx::shape::float_type, {seq_len, num_dirct, batch_size, hidden_size}},
{seq_len, num_dirct, batch_size, hidden_size}}, migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -517,14 +525,15 @@ TEST_CASE(gru) ...@@ -517,14 +525,15 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape( throws_shape(migraphx::op::gru{hidden_size + 1,
migraphx::op::gru{ {migraphx::op::tanh{}},
hidden_size + 1, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, migraphx::op::rnn_direction::forward,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
...@@ -543,14 +552,15 @@ TEST_CASE(gru) ...@@ -543,14 +552,15 @@ TEST_CASE(gru)
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
throws_shape( throws_shape(migraphx::op::gru{hidden_size,
migraphx::op::gru{ {migraphx::op::tanh{}},
hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, migraphx::op::rnn_direction::bidirectional,
in_shape, clip},
w_shape, in_shape,
r_shape, w_shape,
b_shape, r_shape,
ih_shape); b_shape,
ih_shape);
} }
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment