Commit 2e78dc1f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 21a492b2
...@@ -2120,95 +2120,115 @@ TEST_CASE(gru_bidirectional_actv_funcs) ...@@ -2120,95 +2120,115 @@ TEST_CASE(gru_bidirectional_actv_funcs)
} }
} }
TEST_CASE(lstm_forward) TEST_CASE(lstm_forward)
{ {
std::size_t batch_size = 3; std::size_t batch_size = 3;
std::size_t seq_len = 4; std::size_t seq_len = 4;
std::size_t hidden_size = 4; std::size_t hidden_size = 4;
std::size_t input_size = 3; std::size_t input_size = 3;
std::size_t num_dirct = 1; std::size_t num_dirct = 1;
std::vector<float> w_data{ std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.2602, -0.3098, 0.0567, 0.3344, 0.3607, -0.0551, 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
0.4267, 0.2382, -0.0784, -0.0032, -0.2476, -0.0206, -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504,
0.0575, -0.2138, 0.1071, 0.1976, -0.0758, 0.0139,
-0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159,
-0.0522, 0.1685, -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589,
-0.3109, 0.4908, -0.0133, -0.1858, -0.0590, -0.0347, -0.2353, -0.0671,
-0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, -0.3902, 0.0755,
0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418,
-0.2926, -0.3100, 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277,
0.2315, 0.4087, -0.3963, -0.4161, -0.2169, -0.1344, 0.3468, -0.2260};
std::vector<float> bias_data{
0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407,
-0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477,
0.2266, 0.3423, -0.0674, -0.4067, 0.0807, 0.1109, -0.2036, 0.1782,
-0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637, -0.3181, -0.4655};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005,
2.3930, -0.5221, -0.1331, -0.0910, 1.2122, -0.1952,
0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701,
-0.4100, -2.2344, 0.3685, 0.4583, 2.3794, 1.0372,
-0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{ std::vector<float> r_data{
1.9104, -1.9004, 0.3337, 0.5741, 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
0.5671, 0.0458, 0.4514, -0.8968, -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.9201, 0.1962, 0.5771, -0.5332}; -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
std::vector<float> ic_data{ 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.9569, -0.5981, 1.1312, 1.0945, 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
1.1055, -0.1212, -0.9097, 0.7831, -0.2169, -0.1344, 0.3468, -0.2260};
-1.6991, -1.9498, -1.2567, -0.4114};
std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
std::vector<float> pph_data{ -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
0.3860796, -0.52186625, 1.08474445, -1.80867321, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
1.32594529, 0.4336262, -0.83699064, 0.49162736}; -0.3025, 0.3637, -0.3181, -0.4655};
float clip = 0.0f; std::vector<float> input_data{
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4*hidden_size, input_size}}; 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4*hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8*hidden_size}}; std::vector<float> ih_data{1.9104,
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3*hidden_size}}; -1.9004,
0.3337,
0.5741,
0.5671,
0.0458,
0.4514,
-0.8968,
-0.9201,
0.1962,
0.5771,
-0.5332};
std::vector<float> ic_data{0.9569,
-0.5981,
1.1312,
1.0945,
1.1055,
-0.1212,
-0.9097,
0.7831,
-1.6991,
-1.9498,
-1.2567,
-0.4114};
std::vector<float> pph_data{1.84369764,
0.68413646,
-0.44892886,
-1.50904413,
0.3860796,
-0.52186625,
1.08474445,
-1.80867321,
1.32594529,
0.4336262,
-0.83699064,
0.49162736};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// forward, hidden state concatenation as output // forward, hidden state concatenation as output
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::lstm{hidden_size, p.add_instruction(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::lstm{
migraphx::op::rnn_direction::forward, hidden_size,
clip, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
0}, migraphx::op::rnn_direction::forward,
seq, clip,
w, 0},
r, seq,
bias, w,
und, r,
ih, bias,
ic, und,
und); ih,
ic,
und);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
...@@ -2216,45 +2236,42 @@ TEST_CASE(lstm_forward) ...@@ -2216,45 +2236,42 @@ TEST_CASE(lstm_forward)
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{
0.0417273, -0.272355, 0.206765, 0.223879, 0.0417273, -0.272355, 0.206765, 0.223879, 0.138193, -0.0322939, -0.0891815,
0.138193, -0.0322939, -0.0891815, 0.15773, 0.15773, 0.19139, -0.127708, -0.409371, -0.136186, 0.0742487, -0.0800085,
0.19139, -0.127708, -0.409371, -0.136186, 0.259897, 0.0670196, 0.184266, 0.0610048, -0.138041, 0.0963885, 0.0213755,
0.0742487, -0.0800085, 0.259897, 0.0670196, -0.146027, -0.0324509, -0.0620429, -0.00532985, 0.0440265, 0.29654, -0.0463156,
0.184266, 0.0610048, -0.138041, 0.0963885, 0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434,
0.0213755, -0.146027, -0.0324509, -0.0620429, 0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113,
-0.00532985, 0.0440265, 0.29654, -0.0463156, 0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
0.0498799, 0.125772, 0.0533032, -0.131413,
0.0988431, -0.018085, -0.159434, 0.030266,
-0.0847427, 0.0874114, 0.304256, -0.0585745,
-0.0223018, 0.131113, 0.135643, -0.0566208,
0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
// forward, last_output as program output // forward, last_output as program output
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, auto hs = p.add_instruction(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::lstm{
migraphx::op::rnn_direction::forward, hidden_size,
clip, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
0}, migraphx::op::rnn_direction::forward,
seq, clip,
w, 0},
r, seq,
bias, w,
und, r,
ih, bias,
ic, und,
und); ih,
ic,
und);
p.add_instruction(migraphx::op::rnn_last_output{}, hs); p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
...@@ -2262,37 +2279,47 @@ TEST_CASE(lstm_forward) ...@@ -2262,37 +2279,47 @@ TEST_CASE(lstm_forward)
std::vector<float> output_data; std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{-0.0847427,
-0.0847427, 0.0874114, 0.304256, -0.0585745, 0.0874114,
-0.0223018, 0.131113, 0.135643, -0.0566208, 0.304256,
0.142701, 0.0342236, -0.198664, 0.0702607}; -0.0585745,
-0.0223018,
0.131113,
0.135643,
-0.0566208,
0.142701,
0.0342236,
-0.198664,
0.0702607};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
// forward, last_cell_output as program output // forward, last_cell_output as program output
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, auto hs = p.add_instruction(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::lstm{
migraphx::op::rnn_direction::forward, hidden_size,
clip, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
0}, migraphx::op::rnn_direction::forward,
seq, clip,
w, 0},
r, seq,
bias, w,
und, r,
ih, bias,
ic, und,
und); ih,
ic,
und);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs); p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
...@@ -2300,10 +2327,18 @@ TEST_CASE(lstm_forward) ...@@ -2300,10 +2327,18 @@ TEST_CASE(lstm_forward)
std::vector<float> output_data; std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{-0.111454,
-0.111454, 0.247794, 0.471087, -0.220574, 0.247794,
-0.048196, 0.263184, 0.283258, -0.14882, 0.471087,
0.605585, 0.078598, -0.64457, 0.119811}; -0.220574,
-0.048196,
0.263184,
0.283258,
-0.14882,
0.605585,
0.078598,
-0.64457,
0.119811};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
...@@ -2311,16 +2346,18 @@ TEST_CASE(lstm_forward) ...@@ -2311,16 +2346,18 @@ TEST_CASE(lstm_forward)
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::lstm{hidden_size, p.add_instruction(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::lstm{
migraphx::op::rnn_direction::forward, hidden_size,
clip, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
0}, migraphx::op::rnn_direction::forward,
seq, clip,
w, 0},
r); seq,
w,
r);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto last_hs = p.eval({}); auto last_hs = p.eval({});
...@@ -2328,46 +2365,43 @@ TEST_CASE(lstm_forward) ...@@ -2328,46 +2365,43 @@ TEST_CASE(lstm_forward)
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
0.0319021, -0.00298698, -0.0623361, 0.0598866, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.0786602, -0.0613048,
0.101585, 0.0687269, -0.161725, -0.25617, 0.179592, -0.071286, 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633,
-0.0786602, -0.0613048, 0.179592, -0.071286, -0.0552699, 0.0252681, -0.0562072, -0.102509, -0.0372696, 0.252296, -0.144544,
0.074206, 0.0124086, -0.139544, 0.108016, 0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202,
-0.00973633, -0.0552699, 0.0252681, -0.0562072, 0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315,
-0.102509, -0.0372696, 0.252296, -0.144544, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774};
0.00496085, 0.0662588, -0.048577, -0.187329,
0.0855831, -0.0171894, -0.140202, 0.0828391,
-0.165194, -0.0372928, 0.273786, -0.100877,
-0.0458544, -0.0401315, 0.0737483, -0.064505,
0.136898, 0.00160891, -0.184812, 0.147774};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
// forward, 8 args // forward, 8 args
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data}); auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::lstm{hidden_size, p.add_instruction(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::lstm{
migraphx::op::rnn_direction::forward, hidden_size,
clip, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
0}, migraphx::op::rnn_direction::forward,
seq, clip,
w, 0},
r, seq,
bias, w,
und, r,
ih, bias,
ic, und,
pph); ih,
ic,
pph);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
...@@ -2375,52 +2409,47 @@ TEST_CASE(lstm_forward) ...@@ -2375,52 +2409,47 @@ TEST_CASE(lstm_forward)
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
0.294074, -0.0319677, -0.0955337, 0.104168, 0.104168, 0.022618, -0.121195, -0.4065, -0.252054, 0.186991, -0.0624168,
0.022618, -0.121195, -0.4065, -0.252054, 0.205513, 0.0836373, 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906,
0.186991, -0.0624168, 0.205513, 0.0836373, -0.0890598, -0.135266, -0.0413375, 0.0459032, 0.0414126, 0.272303, 0.0393149,
0.421857, 0.0459771, -0.144955, 0.0720673, 0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408,
-0.0300906, -0.0890598, -0.135266, -0.0413375, 0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544,
0.0459032, 0.0414126, 0.272303, 0.0393149, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723};
0.218258, 0.0944405, 0.0431211, -0.132394,
0.103489, 0.0142918, -0.123408, 0.0401075,
-0.058052, 0.0795391, 0.266617, -0.0128746,
0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
// seq_len = 1 // seq_len = 1
{ {
seq_len = 1; seq_len = 1;
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input_data1{ std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
-0.4313, -0.9730, -0.2005,
2.3930, -0.5221, -0.1331};
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1}); auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data}); auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{}); auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, auto hs = p.add_instruction(
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::lstm{
migraphx::op::rnn_direction::forward, hidden_size,
clip, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
0}, migraphx::op::rnn_direction::forward,
seq, clip,
w, 0},
r, seq,
bias, w,
und, r,
ih, bias,
ic, und,
pph); ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, hs); p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{}); p.compile(migraphx::cpu::target{});
...@@ -2428,147 +2457,146 @@ TEST_CASE(lstm_forward) ...@@ -2428,147 +2457,146 @@ TEST_CASE(lstm_forward)
std::vector<float> hs_data; std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{ std::vector<float> hs_data_gold{0.079753,
0.079753, -0.289854, 0.160043, 0.115056, -0.289854,
0.294074, -0.0319677, -0.0955337, 0.104168, 0.160043,
0.022618, -0.121195, -0.4065, -0.252054}; 0.115056,
0.294074,
-0.0319677,
-0.0955337,
0.104168,
0.022618,
-0.121195,
-0.4065,
-0.252054};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
} }
} }
TEST_CASE(lstm_reverse) TEST_CASE(lstm_reverse)
{ {
std::size_t batch_size = 3; std::size_t batch_size = 3;
std::size_t seq_len = 4; std::size_t seq_len = 4;
std::size_t hidden_size = 4; std::size_t hidden_size = 4;
std::size_t input_size = 3; std::size_t input_size = 3;
std::size_t num_dirct = 1; std::size_t num_dirct = 1;
std::vector<float> w_data{ std::vector<float> w_data{
-0.2763, -0.4715, -0.3010, -0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390,
-0.2306, -0.2283, -0.2656, -0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914,
0.2035, 0.3570, -0.1499, -0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401,
0.4390, -0.1843, 0.2351, 0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169,
0.3357, 0.1217, 0.1401, -0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
0.3300, -0.0429, 0.3266,
0.4834, -0.3914, -0.1480,
0.3734, -0.0372, -0.1746,
0.0550, 0.4177, -0.1332,
0.4391, -0.3287, -0.4401,
0.1486, 0.1346, 0.1048,
-0.4361, 0.0886, -0.3840,
-0.2730, -0.1710, 0.3274,
0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756,
0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
-0.4564, -0.4432, 0.1605, 0.4387,
0.0034, 0.4116, 0.2824, 0.4775,
-0.2729, -0.4707, 0.1363, 0.2218,
0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430,
-0.2506, 0.4644, 0.2755, -0.3645,
-0.3155, 0.1425, 0.2891, 0.1786,
-0.3274, 0.2365, 0.2522, -0.4312,
-0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360,
0.2436, 0.0164, -0.0680, 0.3403,
-0.2857, -0.0459, -0.2991, -0.2624,
0.4194, -0.3291, -0.4659, 0.3300,
0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910,
0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{
-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823,
0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794,
0.2167, -0.4474, -0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636,
-0.1582, -0.1703, 0.3920, 0.2055, -0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951,
-0.4313, -0.9730, -0.2005,
2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952,
0.4661, 0.6494, 2.1332,
-1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366,
1.7449, 0.5483, -0.0701,
-0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372,
-0.8887, 0.7892, -0.4012,
-0.2818, -2.3374, 1.5310};
std::vector<float> r_data{
-0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707,
0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430,
-0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365,
0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360,
0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291,
-0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910,
0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632,
-0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896,
0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470,
-0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.5289,
1.0986,
0.6091,
1.6462,
0.8720,
0.5349,
-0.1962,
-1.7416,
-0.9912,
1.2831,
1.0896,
-0.6959};
std::vector<float> ic_data{-0.8323,
0.3998,
0.1831,
0.5938,
2.7096,
-0.1790,
0.0022,
-0.8040,
0.1578,
0.0567,
0.8069,
-0.5141};
std::vector<float> pph_data{-0.8271,
-0.5683,
0.4562,
-1.2545,
1.2729,
-0.4082,
-0.4392,
-0.9406,
0.7794,
1.8194,
-0.5811,
0.2166};
std::vector<float> ih_data{ migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
1.5289, 1.0986, 0.6091, 1.6462, migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
0.8720, 0.5349, -0.1962, -1.7416, migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
-0.9912, 1.2831, 1.0896, -0.6959}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ic_data{ migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
-0.8323, 0.3998, 0.1831, 0.5938, migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
2.7096, -0.1790, 0.0022, -0.8040, float clip = 0.0f;
0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{
-0.8271, -0.5683, 0.4562, -1.2545,
1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4*hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4*hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8*hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3*hidden_size}};
float clip = 0.0f;
// reverse, concatenation of hidden states as program output // reverse, concatenation of hidden states as program output
{ {
migraphx::program p; migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::lstm{hidden_size, auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
migraphx::op::rnn_direction::reverse, auto w = p.add_literal(migraphx::literal{w_shape, w_data});
clip, auto r = p.add_literal(migraphx::literal{r_shape, r_data});
0}, auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
seq, auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
w, auto und = p.add_instruction(migraphx::op::undefined{});
r,
bias, p.add_instruction(
und, migraphx::op::lstm{
ih, hidden_size,
ic, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
pph); migraphx::op::rnn_direction::reverse,
p.compile(migraphx::cpu::target{}); clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({}); auto hs_concat = p.eval({});
std::vector<float> output_data; std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{ std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, -0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909,
0.789732, 0.128538, 0.20909, 0.0553812, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549,
-0.224905, 0.32421, 0.344048, 0.271694, 0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456,
-0.175114, -0.00543549, 0.178681, -0.266999, 0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485,
0.928866, 0.113685, 0.220626, -0.0432316, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353,
-0.063456, 0.148524, 0.05108, -0.0234895, 0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676,
-0.182201, -0.0232277, 0.235501, -0.213485, 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
0.960938, 0.133565, 0.269741, 0.130438,
-0.0252804, 0.267356, 0.146353, 0.0789186,
-0.185038, -0.026845, 0.177273, -0.0774616,
0.946669, 0.0868676, 0.044508, -0.373961,
-0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment