Commit dcaf8fd3 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add bidirectional test case for the lstm operator

parent 731f2dbb
...@@ -2730,7 +2730,7 @@ TEST_CASE(lstm_reverse) ...@@ -2730,7 +2730,7 @@ TEST_CASE(lstm_reverse)
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
// reverse, 3 args, seq_len = 1, last output as program output // reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
{ {
seq_len = 1; seq_len = 1;
std::vector<float> input_data1{ std::vector<float> input_data1{
...@@ -2772,4 +2772,354 @@ TEST_CASE(lstm_reverse) ...@@ -2772,4 +2772,354 @@ TEST_CASE(lstm_reverse)
} }
} }
TEST_CASE(lstm_bidirectional)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149,
0.0795, 0.4934, -0.2858,
0.2602, -0.3098, 0.0567,
0.3344, 0.3607, -0.0551,
0.4952, 0.3799, 0.0630,
-0.3532, 0.0023, -0.0592,
0.4267, 0.2382, -0.0784,
-0.0032, -0.2476, -0.0206,
-0.4963, 0.4837, 0.0827,
0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564,
-0.1286, 0.4090, -0.0504,
0.0575, -0.2138, 0.1071,
0.1976, -0.0758, 0.0139,
-0.0761, 0.3991, -0.2965,
-0.4845, -0.1496, 0.3285,
-0.2763, -0.4715, -0.3010,
-0.2306, -0.2283, -0.2656,
0.2035, 0.3570, -0.1499,
0.4390, -0.1843, 0.2351,
0.3357, 0.1217, 0.1401,
0.3300, -0.0429, 0.3266,
0.4834, -0.3914, -0.1480,
0.3734, -0.0372, -0.1746,
0.0550, 0.4177, -0.1332,
0.4391, -0.3287, -0.4401,
0.1486, 0.1346, 0.1048,
-0.4361, 0.0886, -0.3840,
-0.2730, -0.1710, 0.3274,
0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756,
0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144,
-0.1186, 0.2922, 0.2478, 0.3159,
-0.0522, 0.1685, -0.4621, 0.1728,
0.0670, -0.2458, -0.3835, -0.4589,
-0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671,
-0.3812, -0.0004, -0.1432, 0.2406,
0.1033, -0.0265, -0.3902, 0.0755,
0.3733, 0.4383, -0.3140, 0.2537,
-0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861,
0.4426, 0.2106, -0.0005, 0.4418,
-0.2926, -0.3100, 0.1500, -0.0362,
-0.3801, -0.0065, -0.0631, 0.1277,
0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260,
-0.4564, -0.4432, 0.1605, 0.4387,
0.0034, 0.4116, 0.2824, 0.4775,
-0.2729, -0.4707, 0.1363, 0.2218,
0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430,
-0.2506, 0.4644, 0.2755, -0.3645,
-0.3155, 0.1425, 0.2891, 0.1786,
-0.3274, 0.2365, 0.2522, -0.4312,
-0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360,
0.2436, 0.0164, -0.0680, 0.3403,
-0.2857, -0.0459, -0.2991, -0.2624,
0.4194, -0.3291, -0.4659, 0.3300,
0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910,
0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{
0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407,
-0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477,
0.2266, 0.3423, -0.0674, -0.4067, 0.0807, 0.1109, -0.2036, 0.1782,
-0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637, -0.3181, -0.4655,
-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823,
0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794,
0.2167, -0.4474, -0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636,
-0.1582, -0.1703, 0.3920, 0.2055, -0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951,
-0.4313, -0.9730, -0.2005,
2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952,
0.4661, 0.6494, 2.1332,
-1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366,
1.7449, 0.5483, -0.0701,
-0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372,
-0.8887, 0.7892, -0.4012,
-0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{
1.9104, -1.9004, 0.3337, 0.5741,
0.5671, 0.0458, 0.4514, -0.8968,
-0.9201, 0.1962, 0.5771, -0.5332,
1.5289, 1.0986, 0.6091, 1.6462,
0.8720, 0.5349, -0.1962, -1.7416,
-0.9912, 1.2831, 1.0896, -0.6959};
std::vector<float> ic_data{
0.9569, -0.5981, 1.1312, 1.0945,
1.1055, -0.1212, -0.9097, 0.7831,
-1.6991, -1.9498, -1.2567, -0.4114,
-0.8323, 0.3998, 0.1831, 0.5938,
2.7096, -0.1790, 0.0022, -0.8040,
0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{
1.84369764, 0.68413646, -0.44892886, -1.50904413,
0.3860796, -0.52186625, 1.08474445, -1.80867321,
1.32594529, 0.4336262, -0.83699064, 0.49162736,
-0.8271, -0.5683, 0.4562, -1.2545,
1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4*hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4*hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8*hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3*hidden_size}};
// concatenation of hidden states as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056,
0.294074, -0.0319677, -0.0955337, 0.104168,
0.022618, -0.121195, -0.4065, -0.252054,
-0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812,
-0.224905, 0.32421, 0.344048, 0.271694,
0.186991, -0.0624168, 0.205513, 0.0836373,
0.421857, 0.0459771, -0.144955, 0.0720673,
-0.0300906, -0.0890598, -0.135266, -0.0413375,
-0.175114, -0.00543549, 0.178681, -0.266999,
0.928866, 0.113685, 0.220626, -0.0432316,
-0.063456, 0.148524, 0.05108, -0.0234895,
0.0459032, 0.0414126, 0.272303, 0.0393149,
0.218258, 0.0944405, 0.0431211, -0.132394,
0.103489, 0.0142918, -0.123408, 0.0401075,
-0.182201, -0.0232277, 0.235501, -0.213485,
0.960938, 0.133565, 0.269741, 0.130438,
-0.0252804, 0.267356, 0.146353, 0.0789186,
-0.058052, 0.0795391, 0.266617, -0.0128746,
0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616,
0.946669, 0.0868676, 0.044508, -0.373961,
-0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// last hidden state as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::rnn_last_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746,
0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723,
-0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812,
-0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// last cell output as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto ih = p.add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = p.add_literal(migraphx::literal{ic_shape, ic_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
auto pph = p.add_literal(migraphx::literal{pph_shape, pph_data});
auto und = p.add_instruction(migraphx::op::undefined{});
auto hs = p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
p.add_instruction(migraphx::op::lstm_last_cell_output{}, hs);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759,
0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845,
-0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813,
-0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// 3 args, concatenation of hidden states as program output
{
migraphx::program p;
auto seq = p.add_literal(migraphx::literal{in_shape, input_data});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855,
0.0319021, -0.00298698, -0.0623361, 0.0598866,
0.101585, 0.0687269, -0.161725, -0.25617,
-0.162851, -0.102647, -0.113827, -0.142818,
0.0513685, 0.0547876, 0.0201981, -0.00808453,
-0.00520328, 0.0945081, 0.264123, 0.410805,
-0.0786602, -0.0613048, 0.179592, -0.071286,
0.074206, 0.0124086, -0.139544, 0.108016,
-0.00973633, -0.0552699, 0.0252681, -0.0562072,
-0.123496, -0.153616, -0.032874, -0.195349,
0.0192675, -0.108636, 0.098927, -0.140733,
0.162602, 0.0143099, -0.0455534, 0.0151574,
-0.102509, -0.0372696, 0.252296, -0.144544,
0.00496085, 0.0662588, -0.048577, -0.187329,
0.0855831, -0.0171894, -0.140202, 0.0828391,
-0.1073, -0.150145, 0.015065, -0.192699,
-0.112764, -0.120496, 0.155754, 0.148256,
0.208491, 0.348432, 0.0291103, 0.230275,
-0.165194, -0.0372928, 0.273786, -0.100877,
-0.0458544, -0.0401315, 0.0737483, -0.064505,
0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097,
-0.0051453, -0.0767618, -0.0735348, -0.0826436,
0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
// sequence length is 1, contenation of hidden state as program output
{
migraphx::program p;
seq_len = 1;
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951,
-0.4313, -0.9730, -0.2005,
2.3930, -0.5221, -0.1331};
auto seq = p.add_literal(migraphx::literal{in_shape1, input_data1});
auto w = p.add_literal(migraphx::literal{w_shape, w_data});
auto r = p.add_literal(migraphx::literal{r_shape, r_data});
p.add_instruction(migraphx::op::lstm{hidden_size,
{migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
migraphx::op::rnn_direction::bidirectional,
clip,
0},
seq,
w,
r);
p.compile(migraphx::cpu::target{});
auto hs_concat = p.eval({});
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855,
0.0319021, -0.00298698, -0.0623361, 0.0598866,
0.101585, 0.0687269, -0.161725, -0.25617,
-0.104351, -0.0471426, -0.0905753, 0.01506,
0.059797, 0.104239, -0.0266768, 0.0727547,
-0.146298, 0.070535, 0.327809, 0.407388};
EXPECT(migraphx::verify_range(output_data, output_data_gold));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment