Commit 731f2dbb authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent cc6a6e8d
...@@ -2600,9 +2600,9 @@ TEST_CASE(lstm_reverse) ...@@ -2600,9 +2600,9 @@ TEST_CASE(lstm_reverse)
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction( auto hs = p.add_instruction(
migraphx::op::lstm{ migraphx::op::lstm{
hidden_size, hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
...@@ -2618,10 +2618,18 @@ TEST_CASE(lstm_reverse) ...@@ -2618,10 +2618,18 @@ TEST_CASE(lstm_reverse)
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{-0.443077,
-0.443077, -0.325425, -0.249367, -0.270812, -0.325425,
0.122913, 0.118537, 0.0370199, -0.0164687, -0.249367,
-0.00754759, 0.141613, 0.348002, 0.667298}; -0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
...@@ -2629,15 +2637,10 @@ TEST_CASE(lstm_reverse) ...@@ -2629,15 +2637,10 @@ TEST_CASE(lstm_reverse)
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction( auto hs = p.add_instruction(
migraphx::op::lstm{ migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 0},
hidden_size,
{},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq, seq,
w, w,
r); r);
...@@ -2647,10 +2650,18 @@ TEST_CASE(lstm_reverse) ...@@ -2647,10 +2650,18 @@ TEST_CASE(lstm_reverse)
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{-0.443077,
-0.443077, -0.325425, -0.249367, -0.270812, -0.325425,
0.122913, 0.118537, 0.0370199, -0.0164687, -0.249367,
-0.00754759, 0.141613, 0.348002, 0.667298}; -0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
...@@ -2659,35 +2670,27 @@ TEST_CASE(lstm_reverse) ...@@ -2659,35 +2670,27 @@ TEST_CASE(lstm_reverse)
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction( p.add_instruction(migraphx::op::lstm{hidden_size,
migraphx::op::lstm{ {migraphx::op::sigmoid{}},
hidden_size, migraphx::op::rnn_direction::reverse,
{migraphx::op::sigmoid{}}, clip,
migraphx::op::rnn_direction::reverse, 0},
clip, seq,
0}, w,
seq, r);
w,
r);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
0.246078, 0.199709, 0.303753, 0.301178, 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934,
0.264634, 0.304661, 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231,
0.405483, 0.445586, 0.515814, 0.473186, 0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213,
0.301937, 0.264893, 0.254353, 0.269231, 0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.359258, 0.400097, 0.288884, 0.247329, 0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.276519, 0.264249, 0.1769, 0.23213, 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
0.310306, 0.262902, 0.276964, 0.295002,
0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262,
0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699,
0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
...@@ -2696,42 +2699,49 @@ TEST_CASE(lstm_reverse) ...@@ -2696,42 +2699,49 @@ TEST_CASE(lstm_reverse)
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction( auto hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hidden_size,
hidden_size, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::reverse,
migraphx::op::rnn_direction::reverse, clip,
clip, 0},
0}, seq,
seq, w,
w, r);
r); p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{-0.132123,
-0.132123, -0.37531, -0.12943, -0.00798307, -0.37531,
-0.133882, -0.0251383, 0.0486486, -0.0220606, -0.12943,
0.292495, 0.233866, 0.48646, 0.481844}; -0.00798307,
-0.133882,
-0.0251383,
0.0486486,
-0.0220606,
0.292495,
0.233866,
0.48646,
0.481844};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
// reverse, 3 args, seq_len = 1, last output as program output // reverse, 3 args, seq_len = 1, last output as program output
{ {
seq_len = 1; seq_len = 1;
std::vector<float> input_data1{ std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1}); auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction( p.add_instruction(
migraphx::op::lstm{ migraphx::op::lstm{
hidden_size, hidden_size,
...@@ -2746,10 +2756,18 @@ TEST_CASE(lstm_reverse) ...@@ -2746,10 +2756,18 @@ TEST_CASE(lstm_reverse)
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{-0.104351,
-0.104351, -0.0471426, -0.0905753, 0.01506, -0.0471426,
0.059797, 0.104239, -0.0266768, 0.0727547, -0.0905753,
-0.146298, 0.070535, 0.327809, 0.407388}; 0.01506,
0.059797,
0.104239,
-0.0266768,
0.0727547,
-0.146298,
0.070535,
0.327809,
0.407388};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment