Commit 731f2dbb authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent cc6a6e8d
...@@ -2618,10 +2618,18 @@ TEST_CASE(lstm_reverse) ...@@ -2618,10 +2618,18 @@ TEST_CASE(lstm_reverse)
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{-0.443077,
-0.443077, -0.325425, -0.249367, -0.270812, -0.325425,
0.122913, 0.118537, 0.0370199, -0.0164687, -0.249367,
-0.00754759, 0.141613, 0.348002, 0.667298}; -0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
...@@ -2632,12 +2640,7 @@ TEST_CASE(lstm_reverse) ...@@ -2632,12 +2640,7 @@ TEST_CASE(lstm_reverse)
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction( auto hs = p.add_instruction(
migraphx::op::lstm{ migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 0},
hidden_size,
{},
migraphx::op::rnn_direction::reverse,
clip,
0},
seq, seq,
w, w,
r); r);
...@@ -2647,10 +2650,18 @@ TEST_CASE(lstm_reverse) ...@@ -2647,10 +2650,18 @@ TEST_CASE(lstm_reverse)
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{-0.443077,
-0.443077, -0.325425, -0.249367, -0.270812, -0.325425,
0.122913, 0.118537, 0.0370199, -0.0164687, -0.249367,
-0.00754759, 0.141613, 0.348002, 0.667298}; -0.270812,
0.122913,
0.118537,
0.0370199,
-0.0164687,
-0.00754759,
0.141613,
0.348002,
0.667298};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
...@@ -2661,9 +2672,7 @@ TEST_CASE(lstm_reverse) ...@@ -2661,9 +2672,7 @@ TEST_CASE(lstm_reverse)
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction( p.add_instruction(migraphx::op::lstm{hidden_size,
migraphx::op::lstm{
hidden_size,
{migraphx::op::sigmoid{}}, {migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::reverse,
clip, clip,
...@@ -2676,18 +2685,12 @@ TEST_CASE(lstm_reverse) ...@@ -2676,18 +2685,12 @@ TEST_CASE(lstm_reverse)
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
0.246078, 0.199709, 0.303753, 0.301178, 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934,
0.264634, 0.304661, 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231,
0.405483, 0.445586, 0.515814, 0.473186, 0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213,
0.301937, 0.264893, 0.254353, 0.269231, 0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.359258, 0.400097, 0.288884, 0.247329, 0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.276519, 0.264249, 0.1769, 0.23213, 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
0.310306, 0.262902, 0.276964, 0.295002,
0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262,
0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699,
0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
...@@ -2698,9 +2701,8 @@ TEST_CASE(lstm_reverse) ...@@ -2698,9 +2701,8 @@ TEST_CASE(lstm_reverse)
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction( auto hs =
migraphx::op::lstm{ p.add_instruction(migraphx::op::lstm{hidden_size,
hidden_size,
{migraphx::op::tanh{}, migraphx::op::sigmoid{}}, {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
migraphx::op::rnn_direction::reverse, migraphx::op::rnn_direction::reverse,
clip, clip,
...@@ -2713,10 +2715,18 @@ TEST_CASE(lstm_reverse) ...@@ -2713,10 +2715,18 @@ TEST_CASE(lstm_reverse)
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{-0.132123,
-0.132123, -0.37531, -0.12943, -0.00798307, -0.37531,
-0.133882, -0.0251383, 0.0486486, -0.0220606, -0.12943,
0.292495, 0.233866, 0.48646, 0.481844}; -0.00798307,
-0.133882,
-0.0251383,
0.0486486,
-0.0220606,
0.292495,
0.233866,
0.48646,
0.481844};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
...@@ -2746,10 +2756,18 @@ TEST_CASE(lstm_reverse) ...@@ -2746,10 +2756,18 @@ TEST_CASE(lstm_reverse)
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{-0.104351,
-0.104351, -0.0471426, -0.0905753, 0.01506, -0.0471426,
0.059797, 0.104239, -0.0266768, 0.0727547, -0.0905753,
-0.146298, 0.070535, 0.327809, 0.407388}; 0.01506,
0.059797,
0.104239,
-0.0266768,
0.0727547,
-0.146298,
0.070535,
0.327809,
0.407388};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment