#include #include #include #include #include #include #include #include #include #include #include #include #include #include "test.hpp" TEST_CASE(rnn_forward) { std::size_t batch_size = 2; std::size_t seq_len = 2; std::size_t hidden_size = 4; std::size_t input_size = 3; std::size_t num_dirct = 1; std::vector w_data{0.4691, 0.3185, -0.2227, 0.4423, -0.0609, -0.2803, 0.1744, 0.3146, 0.4049, -0.3973, -0.0890, -0.1636}; std::vector r_data{-0.0456, 0.1061, 0.1574, -0.4928, -0.4300, -0.1909, -0.0225, -0.2668, 0.1840, -0.4453, -0.4896, 0.1302, -0.0929, 0.3545, -0.4981, 0.0616}; std::vector bias_data{ -0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990}; std::vector ih_data(num_dirct * batch_size * hidden_size, 0); std::vector input(seq_len * batch_size * input_size, 0); input[0] = input[1] = 1.0; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; float clip = 0.0f; // concatenation of hidden states as program output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); p.add_instruction(migraphx::op::rnn{hidden_size, {migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; 
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236, 0.13104209, -0.18736027, 0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // rnn last output as program output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction( migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.compile(migraphx::cpu::target{}); auto last_output = p.eval({}); std::vector last_output_data; last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); std::vector last_output_data_gold{0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477}; EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); } // multiple rnn_last_output operators { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction( migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, seq, w, r, bias, und, ih); 
p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.compile(migraphx::cpu::target{}); auto last_output = p.eval({}); std::vector last_output_data; last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); std::vector last_output_data_gold{0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477}; EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); } // 3 args { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto out_hs = p.add_instruction( migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::forward, clip}, seq, w, r); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.compile(migraphx::cpu::target{}); auto last_output = p.eval({}); std::vector last_output_data; last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); std::vector last_output_data_gold{ 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0.}; EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); } // seq_len = 1 { seq_len = 1; std::vector input_1(seq_len * batch_size * input_size, 0); input_1[0] = input_1[1] = 1.0; migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape_1, input_1}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); p.add_instruction(migraphx::op::rnn{hidden_size, {migraphx::op::tanh{}, migraphx::op::tanh{}}, 
migraphx::op::rnn_direction::forward, clip}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236, 0.13104209, -0.18736027}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } } TEST_CASE(rnn_reverse) { std::size_t batch_size = 2; std::size_t seq_len = 2; std::size_t hidden_size = 4; std::size_t input_size = 3; std::size_t num_dirct = 1; std::vector w_data{-0.0296, -0.1341, 0.1761, -0.2325, -0.0717, 0.1852, 0.2720, 0.1471, -0.1097, 0.3363, -0.0587, -0.2302}; std::vector r_data{0.2528, -0.2333, 0.3973, 0.1593, -0.0388, 0.1702, 0.3829, -0.0712, -0.1668, 0.3074, -0.2854, 0.4049, -0.3737, -0.1051, 0.4482, -0.2841}; std::vector bias_data{-0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552}; std::vector input(seq_len * batch_size * input_size, 0); input[0] = input[1] = 1.0; std::vector ih_data(num_dirct * batch_size * hidden_size, 0); float clip = 0.0f; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; // concatenation of hidden states as program output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = 
p.add_instruction(migraphx::op::undefined{}); p.add_instruction( migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635, 0.14803654, -0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031, -0.20639211, 0.37488942}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // rnn last output as program output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction( migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.compile(migraphx::cpu::target{}); auto last_output = p.eval({}); std::vector last_output_data; last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); std::vector last_output_data_gold{-0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635, 0.14803654}; EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); } } TEST_CASE(rnn_bidirectional) { std::size_t batch_size = 2; std::size_t seq_len = 2; std::size_t hidden_size = 4; std::size_t input_size = 3; std::size_t num_dirct = 2; std::vector w_data{0.4691, 0.3185, -0.2227, 0.4423, -0.0609, -0.2803, 0.1744, 0.3146, 0.4049, -0.3973, -0.0890, -0.1636, -0.0296, -0.1341, 0.1761, -0.2325, -0.0717, 0.1852, 0.2720, 
0.1471, -0.1097, 0.3363, -0.0587, -0.2302}; std::vector r_data{-0.0456, 0.1061, 0.1574, -0.4928, -0.4300, -0.1909, -0.0225, -0.2668, 0.1840, -0.4453, -0.4896, 0.1302, -0.0929, 0.3545, -0.4981, 0.0616, 0.2528, -0.2333, 0.3973, 0.1593, -0.0388, 0.1702, 0.3829, -0.0712, -0.1668, 0.3074, -0.2854, 0.4049, -0.3737, -0.1051, 0.4482, -0.2841}; std::vector bias_data{-0.4938, 0.4355, -0.3186, 0.2094, 0.1037, -0.1071, 0.4504, -0.3990, -0.3188, 0.1341, -0.4446, 0.1389, 0.3117, 0.3664, 0.2352, 0.2552}; std::vector input(seq_len * batch_size * input_size, 0); input[0] = input[1] = 1.0; std::vector ih_data(num_dirct * batch_size * hidden_size, 0); migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, hidden_size, input_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, hidden_size, hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 2 * hidden_size}}; float clip = 0.0f; // concatenation of hidden state for program output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); p.add_instruction( migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{ 0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236, 0.13104209, -0.18736027, -0.29385301, 
0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635, 0.14803654, 0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477, -0.0070999, 0.46251031, -0.20639211, 0.37488942, -0.0070999, 0.46251031, -0.20639211, 0.37488942}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // last rnn output for program output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto out_hs = p.add_instruction(migraphx::op::rnn{hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, out_hs); p.compile(migraphx::cpu::target{}); auto last_output = p.eval({}); std::vector last_output_data; last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); std::vector last_output_data_gold{0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477, -0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635, 0.14803654}; EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); } // 4 args { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto out_hs = p.add_instruction(migraphx::op::rnn{hidden_size, {migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r, bias); p.add_instruction(migraphx::op::rnn_last_output{}, 
out_hs); p.compile(migraphx::cpu::target{}); auto last_output = p.eval({}); std::vector last_output_data; last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); std::vector last_output_data_gold{0.03445704, 0.19167931, -0.3946827, -0.30889652, -0.22276389, 0.44193283, -0.16477929, -0.11893477, -0.29385301, 0.16796815, 0.51075965, 0.40258689, -0.13818839, 0.44124447, 0.14365635, 0.14803654}; EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); } // 3 args { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); p.add_instruction(migraphx::op::rnn{hidden_size, {migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r); p.compile(migraphx::cpu::target{}); auto last_output = p.eval({}); std::vector last_output_data; last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); }); std::vector last_output_data_gold{ 0.6570473, 0.36392266, 0.45342238, -0.45127486, 0., 0., 0., 0., -0.16225325, -0.29515147, 0.39617197, 0.27068236, 0., 0., 0., 0., 0.2935145, -0.23719997, -0.31123261, -0.18357255, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}; EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold)); } // concatenation of hidden state for program output { seq_len = 1; std::vector input_1(seq_len * batch_size * input_size, 0); input_1[0] = input_1[1] = 1.0; migraphx::shape in_shape_1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape_1, input_1}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = 
p.add_instruction(migraphx::op::undefined{}); p.add_instruction( migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{0.37780784, 0.61055139, 0.55168478, -0.5888475, -0.37144644, 0.31708236, 0.13104209, -0.18736027, -0.16915828, 0.1938169, 0.20667936, 0.58609703, -0.0070999, 0.46251031, -0.20639211, 0.37488942}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } } TEST_CASE(gru_forward) { std::size_t batch_size = 2; std::size_t seq_len = 3; std::size_t hidden_size = 5; std::size_t input_size = 3; std::size_t num_dirct = 1; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; std::vector w_data{ 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, 0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; std::vector r_data{ 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706, -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 
0.1071, 0.0934, 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; std::vector bias_data{ 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, -0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; std::vector input{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504, -0.3933, 0.5151, -0.2951, 0.0093, -1.1948, -0.1239, 0.0373, 1.3211, 0.7854, -0.4838, -1.0536, -0.2529}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; std::vector ih_data{ -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; float clip = 0.0f; // concatenation of hidden states for output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip, 1}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{ -0.27298412, 0.42363745, -0.09368783, 0.4823072, -0.02183238, -0.6873896, 0.16144305, 0.31932795, 0.6104771, 0.79759157, -0.31791314, 0.5249062, 0.08800987, 0.46404213, -0.11872687, -0.26210734, 0.34448293, -0.0176422, 0.48523626, 0.60002893, -0.3969709, 0.43360898, 0.35775262, 
0.23280787, -0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // last output for output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto concat_hs = p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip, 1}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // two rnn_last_output operators after gru { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto concat_hs = p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip, 1}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.compile(migraphx::cpu::target{}); auto hs_concat = 
p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // last output for output, linear_before_reset = 0 { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto concat_hs = p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip, 0}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.53291196, 0.50160867, 0.39010462, 0.39292926, -0.5960838, -0.38451535, 0.454239, -0.10620412, 0.6014447, 0.43445644}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } } TEST_CASE(gru_forward_args) { std::size_t batch_size = 2; std::size_t seq_len = 3; std::size_t hidden_size = 5; std::size_t input_size = 3; std::size_t num_dirct = 1; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; std::vector w_data{ 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, 0.0701, 
0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; std::vector r_data{ 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706, -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934, 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; std::vector bias_data{ 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, -0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; std::vector input{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504, -0.3933, 0.5151, -0.2951, 0.0093, -1.1948, -0.1239, 0.0373, 1.3211, 0.7854, -0.4838, -1.0536, -0.2529}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; std::vector ih_data{ -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; float clip = 0.0f; // 3 args { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, 
migraphx::op::rnn_direction::forward, clip, 1}, seq, w, r); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.114674, -0.129581, -0.218156, -0.140788, -0.114242, -0.346569, 0.321367, -0.0838253, 0.102097, 0.00232137, -0.149055, 0.0590743, -0.0533094, -0.0446122, -0.112588, 0.0153261, 0.168883, -0.326836, 0.0843562, 0.160872, -0.232523, 0.00214573, 0.231693, -0.160475, -0.518952, 0.0467166, 0.12327, -0.374162, 0.137778, 0.251976}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // 4 args (bias is used) { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip, 1}, seq, w, r, bias); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.273619, 0.0931375, -0.104717, 0.0203752, -0.0797887, -0.493948, 0.472118, -0.0336318, 0.332706, 0.0182268, -0.341684, 0.38063, 0.0589275, 0.2644, -0.115737, -0.152324, 0.442277, -0.201626, 0.408909, 0.12905, -0.416866, 0.377186, 0.32922, 0.162214, -0.519973, -0.140072, 0.465076, -0.229563, 0.500164, 0.195166}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // 4 args (ih is used) { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto und = 
p.add_instruction(migraphx::op::undefined{}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip, 1}, seq, w, r, und, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.0801064, 0.27025, -0.20704, 0.333579, -0.0452438, -0.56265, 0.061061, 0.262172, 0.405193, 0.775226, -0.100683, 0.258729, -0.0187297, 0.215815, -0.108936, -0.0941018, 0.129665, -0.159421, 0.190636, 0.597412, -0.197, 0.0885705, 0.269396, -0.0414511, -0.515137, -0.03075, 0.158326, -0.296488, 0.177983, 0.519498}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } } TEST_CASE(gru_forward_actv_funcs) { std::size_t batch_size = 2; std::size_t seq_len = 3; std::size_t hidden_size = 5; std::size_t input_size = 3; std::size_t num_dirct = 1; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; std::vector w_data{ 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, 0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; std::vector r_data{ 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 
0.3090, -0.0706, -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934, 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; std::vector bias_data{ 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, -0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; std::vector input{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504, -0.3933, 0.5151, -0.2951, 0.0093, -1.1948, -0.1239, 0.0373, 1.3211, 0.7854, -0.4838, -1.0536, -0.2529}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; std::vector ih_data{ -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; float clip = 0.0f; // no activation function specified, so default is used. 
{ migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto concat_hs = p.add_instruction( migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::forward, clip, 1}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.3969709, 0.43360898, 0.35775262, 0.23280787, -0.52179873, -0.21944991, 0.4535257, -0.13735442, 0.51757574, 0.50380427}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // 1 activation function (sigmoid) specified { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::forward, clip, 1}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{0.26905832, 0.5669211, 0.20464146, 0.67195725, 0.24752215, 0.11411376, 0.12353572, 0.4245067, 0.73908687, 0.8644615, 0.34754312, 0.61424744, 0.36769435, 0.6499579, 0.3168031, 0.3296533, 0.3055136, 0.42514813, 0.6851256, 0.7967266, 
0.35652235, 0.6033026, 0.52634895, 0.5815402, 0.3001663, 0.39814138, 0.4354002, 0.4310627, 0.6708563, 0.7509278}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // 1 activation function (tanh) specified { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto concat_hs = p.add_instruction( migraphx::op::gru{ hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip, 1}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.49333298, -0.06104589, 0.5629142, -0.97955984, -0.9314696, -0.03033514, 0.5280315, -0.27354342, 0.65615714, 0.53612584}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // seq length of 1 { migraphx::program p; seq_len = 1; migraphx::shape in_shape_one{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; std::vector input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504}; auto seq = p.add_literal(migraphx::literal{in_shape_one, input_one}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::forward, clip, 1}, seq, w, r, bias, und, ih); 
p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.27298412, 0.42363745, -0.09368783, 0.4823072, -0.02183238, -0.6873896, 0.16144305, 0.31932795, 0.6104771, 0.79759157}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } } TEST_CASE(gru_reverse) { std::size_t batch_size = 2; std::size_t seq_len = 3; std::size_t hidden_size = 5; std::size_t input_size = 3; std::size_t num_dirct = 1; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; std::vector w_data{ 0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418, 0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640, -0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498, 0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331, 0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; std::vector r_data{ 0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529, -0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131, 0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721, -0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179, -0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706, -0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801, 0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934, 0.3645, -0.4310, -0.3480, 0.0702, -0.1558}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; std::vector bias_data{ 0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946, 
-0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494, 0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; std::vector input{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504, -0.3933, 0.5151, -0.2951, 0.0093, -1.1948, -0.1239, 0.0373, 1.3211, 0.7854, -0.4838, -1.0536, -0.2529}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; std::vector ih_data{ -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; float clip = 0.0f; // concatenation of hidden states for output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip, 1}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.263403, 0.317655, -0.00634162, 0.200443, -0.349125, -0.600874, 0.542386, -0.0856531, 0.55703, 0.54711, -0.276245, 0.521348, 0.302874, 0.394353, -0.334369, -0.187861, 0.213553, -0.0708377, 0.545435, 0.654301, -0.329512, 0.476095, 0.284044, 0.392077, -0.369226, -0.3275, -0.027301, 0.143774, 0.655686, 0.782831}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // last output for output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); 
auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto concat_hs = p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip, 1}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.263403, 0.317655, -0.00634162, 0.200443, -0.349125, -0.600874, 0.542386, -0.0856531, 0.55703, 0.54711}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // last output for output, linear_before_reset = 0 { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto concat_hs = p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip, 0}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.388654, 0.384975, 0.0179455, 0.350101, -0.456872, -0.690085, 0.534512, -0.0558191, 0.646604, 0.463943}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // no activation function specified, so default is used. 
{ migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); p.add_instruction( migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 1}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.263403, 0.317655, -0.00634162, 0.200443, -0.349125, -0.600874, 0.542386, -0.0856531, 0.55703, 0.54711, -0.276245, 0.521348, 0.302874, 0.394353, -0.334369, -0.187861, 0.213553, -0.0708377, 0.545435, 0.654301, -0.329512, 0.476095, 0.284044, 0.392077, -0.369226, -0.3275, -0.027301, 0.143774, 0.655686, 0.782831}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // seq length of 1 { migraphx::program p; seq_len = 1; migraphx::shape in_shape_one{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; std::vector input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504}; auto seq = p.add_literal(migraphx::literal{in_shape_one, input_one}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::reverse, clip, 1}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { 
hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.272984, 0.423637, -0.0936878, 0.482307, -0.0218324, -0.68739, 0.161443, 0.319328, 0.610477, 0.797592}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } } TEST_CASE(gru_bidirectional) { std::size_t batch_size = 2; std::size_t seq_len = 3; std::size_t hidden_size = 5; std::size_t input_size = 3; std::size_t num_dirct = 2; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; std::vector w_data{ 0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418, 0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109, 0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732, 0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294, 0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778, -0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353, 0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408, -0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714, -0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996, -0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; std::vector r_data{ -0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063, 0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194, -0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082, 0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609, 0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339, -0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534, -0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, 
-0.2305, 0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440, 0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074, 0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677, -0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618, -0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997, 0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027, 0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955, -0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; std::vector bias_data{ -0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363, -0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817, -0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416, 0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317, 0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377, 0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; std::vector input{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504, -0.3933, 0.5151, -0.2951, 0.0093, -1.1948, -0.1239, 0.0373, 1.3211, 0.7854, -0.4838, -1.0536, -0.2529}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; std::vector ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; float clip = 0.0f; // concatenation of hidden states for output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, 
input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 1}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{ 0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, 0.214342, -0.0454273, -0.135177, -0.0800739, 0.903659, 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229, -0.0930435, 0.174108, -0.063834, 0.0909285, 0.22759, -0.221983, -0.139656, -0.0938906, -0.247681, 0.69647, -0.159396, 0.299061, -0.116652, 0.238649, 0.109945, 0.192866, 0.307073, 0.191113, 0.658287, -0.0340374, -0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533, -0.311839, -0.12802, -0.16643, -0.393849, 0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665, 0.079043, 0.322652, 0.752701, 0.243775}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // last output for output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto concat_hs = p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 1}, seq, w, r, bias, und, ih); 
p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533, -0.311839, -0.12802, -0.16643, -0.393849, 0.648851, 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // last output for output, linear_before_reset = 0 { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto concat_hs = p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{ -0.09280921, 0.18506107, 0.32247013, 0.17034212, -0.00115255, -0.29865006, -0.04513004, -0.10688055, -0.4767866, 0.6317833, 0.00286336, 0.53692746, -0.00617076, 0.04564289, -0.18030001, 0.39584228, 0.53879917, 0.384983, 0.2759448, 0.11611474}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } } TEST_CASE(gru_bidirectional_args) { std::size_t batch_size = 2; std::size_t seq_len = 3; std::size_t hidden_size = 5; std::size_t input_size = 3; std::size_t num_dirct = 2; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; 
std::vector w_data{ 0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418, 0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109, 0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732, 0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294, 0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778, -0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353, 0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408, -0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714, -0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996, -0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; std::vector r_data{ -0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063, 0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194, -0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082, 0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609, 0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339, -0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534, -0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305, 0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440, 0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074, 0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677, -0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618, -0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997, 0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 
0.1027, 0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955, -0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; std::vector bias_data{ -0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363, -0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817, -0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416, 0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317, 0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377, 0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; std::vector input{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504, -0.3933, 0.5151, -0.2951, 0.0093, -1.1948, -0.1239, 0.0373, 1.3211, 0.7854, -0.4838, -1.0536, -0.2529}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; std::vector ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; float clip = 0.0f; // 3 args { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{ 0.0863793, -0.227845, 0.0283059, -0.258645, 0.14187, 0.43541, 
0.190748, -0.530196, -0.440444, 0.293767, 0.0402142, 0.0788687, -0.013, -0.233298, -0.0739615, 0.467104, 0.446285, 0.306097, 0.125636, 0.272524, 0.0949838, 0.0522264, -0.0872712, -0.084203, 0.140013, 0.12739, -0.0111171, -0.431119, -0.468382, 0.388067, -0.109174, -0.119064, -0.0242958, -0.180555, 0.118983, 0.341578, 0.275472, 0.0853083, 0.332205, -0.0498387, 0.140338, 0.0319435, 0.247019, 0.275848, -0.158223, 0.0495464, -0.0681034, -0.418158, -0.523234, 0.469122, -0.306578, -0.221095, -0.106449, -0.248934, -0.00682121, 0.288407, 0.198708, 0.0695644, 0.211621, 0.00246037}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // 4 args (bias is used) { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 1}, seq, w, r, bias); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{ -0.156667, -0.248473, 0.0255282, -0.24566, 0.211589, 0.192707, 0.253025, -0.515283, -0.414174, 0.227127, 0.124773, 0.284532, -0.203929, -0.120517, -0.2794, 0.547635, 0.518549, 0.0447674, 0.258461, 0.0502881, -0.219516, 0.0927382, -0.0760062, -0.0906231, 0.237615, -0.215638, 0.0128074, -0.425813, -0.433378, 0.375383, -0.0381738, 0.117793, -0.180851, -0.0841245, -0.116649, 0.419469, 0.393515, -0.076395, 0.427436, -0.264071, -0.185829, 0.0483585, 0.242955, 0.25233, 0.0148512, -0.304127, -0.0616653, -0.411568, -0.491748, 0.476508, -0.313413, -0.0361821, -0.173037, -0.235731, -0.163113, 0.349008, 0.248674, -0.0295413, 0.291437, -0.165005}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } 
// 4 args (ih is used) { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto und = p.add_instruction(migraphx::op::undefined{}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 1}, seq, w, r, und, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{ 0.248571, 0.0982155, 0.00808877, 0.0986508, 0.0969705, 0.434692, -0.141696, -0.164271, -0.121157, 0.863222, -0.0718357, 0.137711, 0.109221, -0.00207995, 0.0331223, 0.262705, 0.346587, 0.457158, 0.240744, 0.404261, 0.222779, 0.179757, -0.0845316, 0.0690347, 0.10204, 0.100155, -0.190286, -0.122062, -0.274379, 0.547281, -0.226753, -0.0397069, 0.120404, 0.171299, 0.259989, 0.0864604, 0.111322, 0.331784, 0.604653, 0.181017, 0.237426, 0.0911999, 0.233106, 0.32996, -0.17175, 0.0190231, -0.154805, -0.205631, -0.405354, 0.519054, -0.380409, -0.0350301, -0.00633752, 0.403791, 0.181883, -0.0977917, -0.0339407, 0.413089, 0.721238, 0.431879}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } } TEST_CASE(gru_bidirectional_actv_funcs) { std::size_t batch_size = 2; std::size_t seq_len = 3; std::size_t hidden_size = 5; std::size_t input_size = 3; std::size_t num_dirct = 2; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}}; std::vector w_data{ 0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418, 0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109, 0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732, 0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, 
-0.1057, 0.2204, -0.3294, 0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778, -0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353, 0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408, -0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714, -0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996, -0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}}; std::vector r_data{ -0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063, 0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194, -0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082, 0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609, 0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339, -0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534, -0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305, 0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440, 0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074, 0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677, -0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618, -0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997, 0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027, 0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955, -0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}}; std::vector bias_data{ -0.1582, 
-0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363, -0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817, -0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416, 0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317, 0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377, 0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124}; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; std::vector input{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504, -0.3933, 0.5151, -0.2951, 0.0093, -1.1948, -0.1239, 0.0373, 1.3211, 0.7854, -0.4838, -1.0536, -0.2529}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; std::vector ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212}; float clip = 0.0f; // no activation function specified, so default is used. 
{ migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto concat_hs = p.add_instruction( migraphx::op::gru{hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip, 1}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{-0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533, -0.311839, -0.12802, -0.16643, -0.393849, 0.648851, 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // 1 activation function (sigmoid) specified { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{ 0.325495, 0.469214, 0.164517, 0.585327, 0.328398, 0.457928, 0.065011, 0.35986, 0.545029, 0.859425, 0.427923, 0.667133, 0.41591, 0.540971, 
0.365475, 0.482058, 0.565495, 0.556993, 0.607649, 0.543627, 0.428915, 0.537405, 0.306046, 0.518399, 0.403561, 0.410694, 0.301163, 0.407397, 0.471334, 0.726446, 0.309389, 0.612072, 0.360619, 0.590861, 0.366545, 0.367001, 0.433829, 0.501275, 0.72481, 0.512745, 0.463795, 0.539649, 0.487682, 0.554471, 0.395916, 0.430744, 0.415923, 0.424275, 0.409655, 0.698256, 0.126883, 0.554374, 0.216137, 0.671491, 0.263833, 0.0678646, 0.132732, 0.477083, 0.802206, 0.626802}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // 1 activation function (tanh) specified { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 1}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{ 0.0919632, -0.398302, -0.0267752, -0.326771, 0.401983, 0.949841, 0.557779, -0.745259, -1.52726, 0.946066, 0.330446, 0.301982, -0.443763, -0.0655817, -0.326473, 0.861394, 0.560799, -0.101768, 0.145142, 0.128956, -0.329758, 0.458253, -0.339208, 0.289109, 0.36728, -1.09574, -0.181394, -0.575781, -0.823083, 0.804262, -0.0965933, 0.20405, -0.430215, 0.00884668, 0.0716857, 0.844222, 0.516472, -0.191571, 0.596968, -0.545405, -0.336693, -0.0280516, 0.339058, 1.00367, 0.12655, -0.0984504, -0.174945, -0.5365, 0.183188, 0.66716, -0.704461, -0.393346, -0.627123, 0.210395, 0.0563026, 0.31419, 0.759629, 0.000258222, 0.350835, -0.682684}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // 3 
activation functions specified { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto concat_hs = p.add_instruction( migraphx::op::gru{hidden_size, {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 1}, seq, w, r, bias, und, ih); p.add_instruction(migraphx::op::rnn_last_output{}, concat_hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); std::vector hs_data_gold{0.351019, 0.474363, 0.570719, 0.717703, 0.468843, 1.15142, 0.457633, 0.300962, 0.361245, 0.666199, 0.330446, 0.301982, -0.443763, -0.0655817, -0.326473, 0.861394, 0.560799, -0.101768, 0.145142, 0.128956}; EXPECT(migraphx::verify_range(hs_data, hs_data_gold)); } // 4 activation functions all specified { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); p.add_instruction(migraphx::op::gru{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 1}, seq, w, r, bias, und, ih); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector hs_data; hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); }); 
std::vector hs_data_gold{ 0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, 0.214342, -0.0454273, -0.135177, -0.0800739, 0.903659, 0.0248217, 0.435231, -0.144448, 0.101531, -0.111305, 0.381317, 0.468983, 0.230557, 0.348021, 0.180229, -0.0930435, 0.174108, -0.063834, 0.0909285, 0.22759, -0.221983, -0.139656, -0.0938906, -0.247681, 0.69647, -0.159396, 0.299061, -0.116652, 0.238649, 0.109945, 0.192866, 0.307073, 0.191113, 0.658287, -0.0340374, -0.0959787, 0.0794681, 0.241526, 0.321104, 0.00693533, -0.311839, -0.12802, -0.16643, -0.393849, 0.648851, -0.395918, 0.231694, -0.160503, 0.383289, 0.0879262, -0.0254665, 0.079043, 0.322652, 0.752701, 0.243775};
        EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
    }
    // seq length of 1
    {
        migraphx::program p;
        seq_len = 1;
        migraphx::shape in_shape_one{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        std::vector input_one{-0.8432, -0.9887, 1.3041, -2.6430, -0.3306, -0.8504};
        auto seq = p.add_literal(migraphx::literal{in_shape_one, input_one});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
        auto und = p.add_instruction(migraphx::op::undefined{});
        auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
        p.add_instruction(migraphx::op::gru{hidden_size,
                                           {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
                                           migraphx::op::rnn_direction::bidirectional,
                                           clip,
                                           1},
                          seq, w, r, bias, und, ih);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({});
        std::vector hs_data;
        hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
        std::vector hs_data_gold{0.0352243, 0.0146756, 0.00570925, 0.152446, 0.208683, 0.214342, -0.0454273, -0.135177, -0.0800739, 0.903659, -0.0271321, 0.624762, -0.117084, 0.509115, -0.0175078, -0.144492, -0.0115366, 0.409153, 0.487015, 0.550755};
        EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
    }
}

// Forward-direction op::lstm on the CPU target (batch_size 3, seq_len 4,
// hidden_size 4, input_size 3, one direction). Each sub-case builds a fresh
// program from literal inputs, compiles and evaluates it, and compares the
// flattened result against hard-coded expected values with
// migraphx::verify_range.
// NOTE(review): the origin of the expected values (e.g. a reference-framework
// run) is not visible in this file; confirm before regenerating them.
TEST_CASE(lstm_forward)
{
    std::size_t batch_size = 3;
    std::size_t seq_len = 4;
    std::size_t hidden_size = 4;
    std::size_t input_size = 3;
    std::size_t num_dirct = 1;
    std::vector w_data{0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
    std::vector r_data{0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, -0.2169, -0.1344, 0.3468, -0.2260};
    std::vector bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637, -0.3181, -0.4655};
    std::vector input_data{-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
    std::vector ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458, 0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332};
    std::vector ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212, -0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114};
    // NOTE(review): pph_data / pph_shape are declared but never used by any
    // sub-case of this TEST_CASE; every op::lstm call below passes `und` in
    // the last input slot instead.
    std::vector pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796, -0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262, -0.83699064, 0.49162736};
    float clip = 0.0f;
    // W and R carry 4 * hidden_size rows per direction, and the bias holds
    // 8 * hidden_size values per direction (presumably one hidden_size block
    // per LSTM gate for the W-side and R-side biases; confirm against op::lstm).
    migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
    // forward, hidden state concatenation as output
    {
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
        auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
        auto und = p.add_instruction(migraphx::op::undefined{});
        // Inputs: seq, w, r, bias, und, ih, ic, und. The op::undefined
        // instruction `und` fills the 5th and 8th input slots, which this
        // sub-case does not supply. The order appears to follow ONNX LSTM
        // (X, W, R, B, sequence_lens, initial_h, initial_c, P); confirm
        // against op::lstm.
        p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                0},
            seq, w, r, bias, und, ih, ic, und);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({});
        std::vector hs_data;
        hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
        // 48 expected values; the final 12 equal the rnn_last_output
        // expectation in the next sub-case.
        std::vector hs_data_gold{0.0417273, -0.272355, 0.206765, 0.223879, 0.138193, -0.0322939, -0.0891815, 0.15773, 0.19139, -0.127708, -0.409371, -0.136186, 0.0742487, -0.0800085, 0.259897, 0.0670196, 0.184266, 0.0610048, -0.138041, 0.0963885, 0.0213755, -0.146027, -0.0324509, -0.0620429, -0.00532985, 0.0440265, 0.29654, -0.0463156, 0.0498799, 0.125772, 0.0533032, -0.131413, 0.0988431, -0.018085, -0.159434, 0.030266, -0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113, 0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
        EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
    }
    // forward, last_output as program output
    {
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
        auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
        auto und = p.add_instruction(migraphx::op::undefined{});
        auto hs = p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                0},
            seq, w, r, bias, und, ih, ic, und);
        p.add_instruction(migraphx::op::rnn_last_output{}, hs);
        p.compile(migraphx::cpu::target{});
        auto last_hs = p.eval({});
        std::vector output_data;
        last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector output_data_gold{-0.0847427, 0.0874114, 0.304256, -0.0585745, -0.0223018, 0.131113, 0.135643, -0.0566208, 0.142701, 0.0342236, -0.198664, 0.0702607};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }
    // forward, last_cell_output as program output
    {
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
        auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
        auto und = p.add_instruction(migraphx::op::undefined{});
        auto hs = p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                0},
            seq, w, r, bias, und, ih, ic, und);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
        p.compile(migraphx::cpu::target{});
        auto last_hs = p.eval({});
        std::vector output_data;
        last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector output_data_gold{-0.111454, 0.247794, 0.471087, -0.220574, -0.048196, 0.263184, 0.283258, -0.14882, 0.605585, 0.078598, -0.64457, 0.119811};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }
}

// More forward-direction op::lstm cases. They use the same weights, bias,
// input sequence, initial states and peephole data as lstm_forward, and cover
// three variants: only seq, w and r passed; all eight inputs (including
// peephole weights); and a single time step.
TEST_CASE(lstm_forward_more)
{
    std::size_t batch_size = 3;
    std::size_t seq_len = 4;
    std::size_t hidden_size = 4;
    std::size_t input_size = 3;
    std::size_t num_dirct = 1;
    std::vector w_data{0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
    std::vector r_data{0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, -0.2169, -0.1344, 0.3468, -0.2260};
    std::vector bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637, -0.3181, -0.4655};
    std::vector input_data{-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
    std::vector ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458, 0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332};
    std::vector ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212, -0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114};
    std::vector pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796, -0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262, -0.83699064, 0.49162736};
    float clip = 0.0f;
    migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
    // forward, 3 args: only seq, w and r are passed to op::lstm; the other
    // inputs are omitted entirely rather than filled with op::undefined.
    // No rnn_last_output is added, so despite its name `last_hs` holds the
    // full 48-value hidden-state sequence.
    {
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                0},
            seq, w, r);
        p.compile(migraphx::cpu::target{});
        auto last_hs = p.eval({});
        std::vector output_data;
        last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector output_data_gold{-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.0786602, -0.0613048, 0.179592, -0.071286, 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681, -0.0562072, -0.102509, -0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202, 0.0828391, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }
    // forward, 8 args: bias, initial hidden state, initial cell state and
    // peephole weights are all supplied; only the 5th input slot is `und`.
    {
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
        auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
        auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
        auto und = p.add_instruction(migraphx::op::undefined{});
        p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                0},
            seq, w, r, bias, und, ih, ic, pph);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({});
        std::vector hs_data;
        hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
        std::vector hs_data_gold{0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, 0.104168, 0.022618, -0.121195, -0.4065, -0.252054, 0.186991, -0.0624168, 0.205513, 0.0836373, 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266, -0.0413375, 0.0459032, 0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408, 0.0401075, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723};
        EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
    }
    // seq_len = 1. This reassigns the TEST_CASE-level seq_len, which is safe
    // only because this is the last sub-case. input_data1 is the first time
    // step of input_data, so the expected output equals the first 12 values of
    // the "8 args" expectation above.
    {
        seq_len = 1;
        migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        std::vector input_data1{-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
        auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
        auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
        auto und = p.add_instruction(migraphx::op::undefined{});
        auto hs = p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                0},
            seq, w, r, bias, und, ih, ic, pph);
        p.add_instruction(migraphx::op::rnn_last_output{}, hs);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({});
        std::vector hs_data;
        hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
        std::vector hs_data_gold{0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, 0.104168, 0.022618, -0.121195, -0.4065, -0.252054};
        EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
    }
}

// Reverse-direction op::lstm (batch_size 3, seq_len 4, hidden_size 4,
// input_size 3, one direction) with its own weights, bias, initial states and
// peephole data. The input sequence is the same as in lstm_forward.
TEST_CASE(lstm_reverse)
{
    std::size_t batch_size = 3;
    std::size_t seq_len = 4;
    std::size_t hidden_size = 4;
    std::size_t input_size = 3;
    std::size_t num_dirct = 1;
    std::vector w_data{-0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
    std::vector r_data{-0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
    std::vector bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, -0.4386, 0.4208, 0.0717, 0.3789};
    std::vector input_data{-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
    std::vector ih_data{1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349, -0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959};
    std::vector ic_data{-0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790, 0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141};
    std::vector pph_data{-0.8271, -0.5683, 0.4562, -1.2545, 1.2729, -0.4082, -0.4392, -0.9406, 0.7794, 1.8194, -0.5811, 0.2166};
    migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
    float clip = 0.0f;
    // reverse, concatenation of hidden states as program output (all inputs
    // supplied except the 5th slot, which is `und`)
    {
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
        auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
        auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
        auto und = p.add_instruction(migraphx::op::undefined{});
        p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::reverse,
                clip,
                0},
            seq, w, r, bias, und, ih, ic, pph);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({});
        std::vector output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector output_data_gold{-0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, -0.182201, -0.0232277, 0.235501, -0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, 0.0789186, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }
    // reverse, 3 args, last cell output as program output
    {
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        auto hs = p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::reverse,
                clip,
                0},
            seq, w, r);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({});
        std::vector output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector output_data_gold{-0.443077, -0.325425, -0.249367, -0.270812, 0.122913, 0.118537, 0.0370199, -0.0164687, -0.00754759, 0.141613, 0.348002, 0.667298};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }
    // reverse, 3 args, empty activation-function list. The expected values
    // are identical to the previous sub-case, which passes
    // {sigmoid, tanh, tanh} explicitly.
    {
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        auto hs = p.add_instruction(
            migraphx::op::lstm{hidden_size, {}, migraphx::op::rnn_direction::reverse, clip, 0},
            seq, w, r);
        p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({});
        std::vector output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector output_data_gold{-0.443077, -0.325425, -0.249367, -0.270812, 0.122913, 0.118537, 0.0370199, -0.0164687, -0.00754759, 0.141613, 0.348002, 0.667298};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }
    // reverse, 3 args, 1 activation function (sigmoid), full hidden-state
    // sequence as program output
    {
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        p.add_instruction(migraphx::op::lstm{hidden_size,
                                            {migraphx::op::sigmoid{}},
                                            migraphx::op::rnn_direction::reverse,
                                            clip,
                                            0},
                          seq, w, r);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({});
        std::vector output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector output_data_gold{0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231, 0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213, 0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216, 0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634, 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }
    // reverse, 3 args, 2 activation functions (tanh, sigmoid), last hidden
    // state as program output
    {
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
                                                      {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
                                                      migraphx::op::rnn_direction::reverse,
                                                      clip,
                                                      0},
                                    seq, w, r);
        p.add_instruction(migraphx::op::rnn_last_output{}, hs);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({});
        std::vector output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector output_data_gold{-0.132123, -0.37531, -0.12943, -0.00798307, -0.133882, -0.0251383, 0.0486486, -0.0220606, 0.292495, 0.233866, 0.48646, 0.481844};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }
    // reverse, 3 args, seq_len = 1, concatenation of hidden states as program
    // output. This reassigns the TEST_CASE-level seq_len, which is safe only
    // because this is the last sub-case.
    {
        seq_len = 1;
        std::vector input_data1{-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
        migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
        migraphx::program p;
        auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
        auto w = p.add_literal(migraphx::literal{w_shape, w_data});
        auto r = p.add_literal(migraphx::literal{r_shape, r_data});
        p.add_instruction(
            migraphx::op::lstm{
                hidden_size,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::reverse,
                clip,
                0},
            seq, w, r);
        p.compile(migraphx::cpu::target{});
        auto hs_concat = p.eval({});
        std::vector output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector output_data_gold{-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239, -0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
        EXPECT(migraphx::verify_range(output_data, output_data_gold));
    }
}

TEST_CASE(lstm_bidirectional) {
    std::size_t batch_size = 3;
    std::size_t seq_len = 4;
    std::size_t hidden_size = 4;
    std::size_t input_size = 3;
    std::size_t num_dirct = 2;
    std::vector w_data{ 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729, 0.3983,
-0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; std::vector r_data{ 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, -0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038}; std::vector bias_data{ 0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637, -0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055, -0.4386, 0.4208, 0.0717, 0.3789}; std::vector input_data{ -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, 0.4583, 2.3794, 
1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; std::vector ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458, 0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332, 1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349, -0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959}; std::vector ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212, -0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114, -0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790, 0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141}; std::vector pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796, -0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262, -0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562, -1.2545, 1.2729, -0.4082, -0.4392, -0.9406, 0.7794, 1.8194, -0.5811, 0.2166}; float clip = 0.0f; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}}; migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}}; migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}}; // concatenation of hidden states as program output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data}); auto und = 
p.add_instruction(migraphx::op::undefined{}); p.add_instruction( migraphx::op::lstm{ hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r, bias, und, ih, ic, pph); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ 0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337, 0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373, 0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266, -0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685, 0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032, 0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394, 0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501, -0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356, 0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723, -0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508, -0.373961, -0.0681467, 0.382748, 0.230211, -0.161537}; EXPECT(migraphx::verify_range(output_data, output_data_gold)); } // last hidden state as program output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, 
bias_data}); auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto hs = p.add_instruction( migraphx::op::lstm{ hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r, bias, und, ih, ic, pph); p.add_instruction(migraphx::op::rnn_last_output{}, hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694}; EXPECT(migraphx::verify_range(output_data, output_data_gold)); } // last cell output as program output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data}); auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto bias = p.add_literal(migraphx::literal{b_shape, bias_data}); auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data}); auto und = p.add_instruction(migraphx::op::undefined{}); auto hs = p.add_instruction( migraphx::op::lstm{ hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r, bias, und, ih, ic, pph); p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ 
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934, 0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334, 1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713}; EXPECT(migraphx::verify_range(output_data, output_data_gold)); } // 3 args, concatenation of hidden states as program output { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); p.add_instruction( migraphx::op::lstm{ hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286, 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681, -0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636, 0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509, -0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065, -0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432, 0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, -0.0826436, 
0.214159, 0.262295, 0.0247127, 0.14472}; EXPECT(migraphx::verify_range(output_data, output_data_gold)); } // sequence length is 1, contenation of hidden state as program output { migraphx::program p; seq_len = 1; migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; std::vector input_data1{ -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331}; auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); p.add_instruction( migraphx::op::lstm{ hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239, -0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388}; EXPECT(migraphx::verify_range(output_data, output_data_gold)); } } TEST_CASE(lstm_bidirectional_actv_func) { std::size_t batch_size = 3; std::size_t seq_len = 4; std::size_t hidden_size = 4; std::size_t input_size = 3; std::size_t num_dirct = 2; std::vector w_data{ 0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344, 0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382, -0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279, -0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976, -0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 
0.3570, -0.1499, 0.4390, -0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522}; std::vector r_data{ 0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685, -0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858, -0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265, -0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562, 0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100, 0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161, -0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038}; std::vector input_data{ -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331, -0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122, 0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685, 0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310}; float clip = 0.0f; migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}}; migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}}; 
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}}; // 3 args, 0 actv func { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); p.add_instruction( migraphx::op::lstm{ hidden_size, {}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ -0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286, 0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681, -0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636, 0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509, -0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329, 0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065, -0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432, 0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, -0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348, -0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472}; EXPECT(migraphx::verify_range(output_data, output_data_gold)); } // 3 args, 1 actv func { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = 
p.add_literal(migraphx::literal{r_shape, r_data}); p.add_instruction(migraphx::op::lstm{hidden_size, {migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ 0.227861, 0.328562, 0.277867, 0.272945, 0.204389, 0.296123, 0.223834, 0.311113, 0.424666, 0.173974, 0.40628, 0.286631, 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186, 0.339438, 0.29655, 0.331832, 0.242338, 0.409384, 0.236272, 0.306045, 0.26269, 0.261246, 0.334357, 0.23622, 0.245288, 0.301937, 0.264893, 0.254353, 0.269231, 0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213, 0.374123, 0.283167, 0.377129, 0.245726, 0.444712, 0.203168, 0.411446, 0.269965, 0.172792, 0.296224, 0.17319, 0.352547, 0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216, 0.262827, 0.371441, 0.369022, 0.298262, 0.450186, 0.263538, 0.402895, 0.216177, 0.267257, 0.342535, 0.257797, 0.268563, 0.193043, 0.275645, 0.167678, 0.350889, 0.334143, 0.309444, 0.174822, 0.251634, 0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502}; EXPECT(migraphx::verify_range(output_data, output_data_gold)); } // 3 args, 2 actv func { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r); p.add_instruction(migraphx::op::rnn_last_output{}, hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector output_data; hs_concat.visit([&](auto 
output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; EXPECT(migraphx::verify_range(output_data, output_data_gold)); } // 3 args, 4 actv func { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::sigmoid{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r); p.add_instruction(migraphx::op::rnn_last_output{}, hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, 0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934, 0.405483, 0.445586, 0.515814, 0.473186}; EXPECT(migraphx::verify_range(output_data, output_data_gold)); } // 3 args, 5 actv func { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r); p.add_instruction(migraphx::op::rnn_last_output{}, hs); 
p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; EXPECT(migraphx::verify_range(output_data, output_data_gold)); } // 3 args, 6 actv func { migraphx::program p; auto seq = p.add_literal(migraphx::literal{in_shape, input_data}); auto w = p.add_literal(migraphx::literal{w_shape, w_data}); auto r = p.add_literal(migraphx::literal{r_shape, r_data}); auto hs = p.add_instruction(migraphx::op::lstm{hidden_size, {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}}, migraphx::op::rnn_direction::bidirectional, clip, 0}, seq, w, r); p.add_instruction(migraphx::op::rnn_last_output{}, hs); p.compile(migraphx::cpu::target{}); auto hs_concat = p.eval({}); std::vector output_data; hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); }); std::vector output_data_gold{ -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544, -0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774, -0.162851, -0.102647, -0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328, 0.0945081, 0.264123, 0.410805}; EXPECT(migraphx::verify_range(output_data, output_data_gold)); } } int main(int argc, const char* argv[]) { test::run(argc, argv); }