Commit 48c72797 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 55453fe1
...@@ -3158,26 +3158,25 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -3158,26 +3158,25 @@ TEST_CASE(lstm_bidirectional_actv_func)
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, auto hs =
{migraphx::op::sigmoid{}, migraphx::op::tanh{}}, p.add_instruction(migraphx::op::lstm{hidden_size,
migraphx::op::rnn_direction::bidirectional, {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
clip, migraphx::op::rnn_direction::bidirectional,
0}, clip,
seq, 0},
w, seq,
r); w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs); p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
-0.0458544, -0.0401315, 0.0737483, -0.064505, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
0.136898, 0.00160891, -0.184812, 0.147774, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
0.0513685, 0.0547876, 0.0201981, -0.00808453,
-0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
...@@ -3187,27 +3186,27 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -3187,27 +3186,27 @@ TEST_CASE(lstm_bidirectional_actv_func)
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, {migraphx::op::sigmoid{},
migraphx::op::sigmoid{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::tanh{},
clip, migraphx::op::sigmoid{}},
0}, migraphx::op::rnn_direction::bidirectional,
seq, clip,
w, 0},
r); seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs); p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
-0.0458544, -0.0401315, 0.0737483, -0.064505, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
0.136898, 0.00160891, -0.184812, 0.147774, 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661,
0.246078, 0.199709, 0.303753, 0.301178, 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186};
0.264634, 0.304661, 0.349371, 0.288934,
0.405483, 0.445586, 0.515814, 0.473186};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
...@@ -3217,27 +3216,28 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -3217,27 +3216,28 @@ TEST_CASE(lstm_bidirectional_actv_func)
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, {migraphx::op::sigmoid{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::tanh{},
clip, migraphx::op::sigmoid{},
0}, migraphx::op::tanh{}},
seq, migraphx::op::rnn_direction::bidirectional,
w, clip,
r); 0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs); p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
-0.0458544, -0.0401315, 0.0737483, -0.064505, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
0.136898, 0.00160891, -0.184812, 0.147774, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
0.0513685, 0.0547876, 0.0201981, -0.00808453,
-0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
...@@ -3247,27 +3247,29 @@ TEST_CASE(lstm_bidirectional_actv_func) ...@@ -3247,27 +3247,29 @@ TEST_CASE(lstm_bidirectional_actv_func)
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, {migraphx::op::sigmoid{},
migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::tanh{},
migraphx::op::rnn_direction::bidirectional, migraphx::op::tanh{},
clip, migraphx::op::sigmoid{},
0}, migraphx::op::tanh{},
seq, migraphx::op::tanh{}},
w, migraphx::op::rnn_direction::bidirectional,
r); clip,
0},
seq,
w,
r);
p.add_instruction(migraphx::op::rnn_last_output{}, hs); p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.165194, -0.0372928, 0.273786, -0.100877, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
-0.0458544, -0.0401315, 0.0737483, -0.064505, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
0.136898, 0.00160891, -0.184812, 0.147774, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876,
-0.162851, -0.102647, -0.113827, -0.142818, 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805};
0.0513685, 0.0547876, 0.0201981, -0.00808453,
-0.00520328, 0.0945081, 0.264123, 0.410805};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment