Commit fbc53b14 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix clang format issues.

parent a40e58d3
......@@ -736,6 +736,64 @@ TEST_CASE(gru_forward)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_forward_args)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418,
0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640,
-0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498,
0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331,
0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529,
-0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131,
0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721,
-0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179,
-0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706,
-0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801,
0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934,
0.3645, -0.4310, -0.3480, 0.0702, -0.1558};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946,
-0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494,
0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{
-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// 3 args
{
......@@ -833,6 +891,64 @@ TEST_CASE(gru_forward)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_forward_actv_funcs)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3485, -0.0378, -0.1782, 0.1416, -0.3096, -0.2212, -0.3883, 0.1983, -0.2418,
0.1480, -0.3255, 0.1359, -0.3551, -0.3605, -0.3482, -0.1424, -0.0495, -0.1640,
-0.1979, -0.2577, -0.4097, -0.1211, -0.0412, 0.1801, 0.1721, -0.4327, -0.0498,
0.2628, -0.1573, -0.1577, 0.2759, -0.2023, -0.1185, -0.2136, 0.1294, -0.2331,
0.0701, 0.4316, 0.0480, 0.0247, -0.0166, -0.2729, 0.1712, -0.3984, -0.3905};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
0.2848, -0.2851, -0.3466, -0.1718, -0.1492, -0.0082, 0.2452, -0.0401, 0.3399, 0.2529,
-0.0953, -0.0903, -0.1518, -0.1373, 0.3848, -0.0130, -0.4339, 0.0406, -0.1926, -0.1131,
0.4285, -0.0013, 0.2243, 0.2752, 0.1776, -0.1720, 0.0822, -0.0295, 0.1062, -0.2721,
-0.2736, -0.1826, 0.3541, -0.4259, 0.2188, 0.0706, 0.3650, 0.3947, 0.2522, 0.2179,
-0.0744, 0.2122, -0.4346, 0.2760, 0.4076, 0.1183, -0.1500, -0.1704, 0.3090, -0.0706,
-0.2442, 0.3021, 0.1680, 0.0783, -0.3754, -0.3469, -0.2972, -0.0170, 0.4143, 0.3801,
0.3852, -0.1170, -0.2937, 0.2979, -0.1357, 0.4257, 0.3884, -0.2916, 0.1071, 0.0934,
0.3645, -0.4310, -0.3480, 0.0702, -0.1558};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
0.0560, 0.0310, -0.1669, -0.0781, 0.1793, -0.1758, 0.3173, -0.1650, -0.3732, 0.2946,
-0.0912, 0.3118, 0.1391, 0.2755, 0.2695, -0.1059, -0.2357, 0.3629, -0.2534, -0.0494,
0.0556, 0.0881, -0.2592, -0.2213, 0.2310, -0.4044, 0.1801, 0.1438, 0.3108, -0.3607};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{
-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// no activation function specified, so default is used.
{
......@@ -1422,6 +1538,82 @@ TEST_CASE(gru_bidirectional)
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_bidirectional_args)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418,
0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109,
0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732,
0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294,
0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778,
-0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353,
0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408,
-0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714,
-0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996,
-0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
-0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063,
0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194,
-0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082,
0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609,
0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339,
-0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534,
-0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305,
0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440,
0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074,
0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677,
-0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618,
-0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997,
0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027,
0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955,
-0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
-0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363,
-0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817,
-0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416,
0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317,
0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377,
0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348,
0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340,
0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// 3 args
{
......@@ -1530,6 +1722,82 @@ TEST_CASE(gru_bidirectional)
-0.0339407, 0.413089, 0.721238, 0.431879};
EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
}
}
TEST_CASE(gru_bidirectional_actv_funcs)
{
std::size_t batch_size = 2;
std::size_t seq_len = 3;
std::size_t hidden_size = 5;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, input_size}};
std::vector<float> w_data{
0.3809, 0.4283, 0.2294, -0.1018, -0.1226, -0.0037, 0.2449, -0.2712, -0.1418,
0.1363, -0.3453, -0.0693, -0.2281, 0.2699, -0.2024, -0.3085, -0.3338, 0.4109,
0.2605, -0.1019, -0.2813, 0.3323, -0.1590, 0.0788, -0.3535, 0.0397, 0.2732,
0.2906, 0.0519, 0.3617, -0.2664, 0.1441, 0.0464, -0.1057, 0.2204, -0.3294,
0.3670, 0.1411, 0.3852, 0.3572, 0.3918, 0.0483, -0.3906, -0.2841, -0.2778,
-0.4272, 0.2335, -0.1811, -0.3885, -0.1279, 0.1000, 0.0206, -0.3284, -0.0353,
0.1197, 0.1190, 0.3862, 0.0965, -0.0492, 0.2657, -0.1430, 0.0597, 0.1408,
-0.0315, 0.1248, 0.0751, 0.3838, 0.3020, 0.0515, 0.2375, -0.4255, 0.1714,
-0.0432, 0.3447, -0.2441, -0.3989, -0.3428, -0.4204, -0.4080, -0.2683, -0.0996,
-0.1685, -0.0532, -0.1258, 0.1663, -0.3526, -0.3915, -0.1721, 0.1292, -0.2279};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size, hidden_size}};
std::vector<float> r_data{
-0.2683, 0.0699, -0.4021, -0.1379, 0.0042, -0.2447, 0.4006, 0.0270, -0.0446, 0.1063,
0.1381, 0.1310, -0.3596, 0.3869, 0.3929, 0.2750, 0.0890, 0.3069, -0.1691, -0.2194,
-0.1066, 0.3187, -0.4369, -0.0603, -0.0834, -0.1182, -0.2047, 0.3253, -0.2931, 0.2082,
0.0424, 0.1111, -0.2773, -0.0279, -0.0869, 0.1413, -0.4227, -0.3672, 0.4137, 0.0609,
0.4223, -0.4032, 0.2945, 0.3600, 0.3345, -0.3880, -0.0192, -0.0090, -0.2648, 0.4339,
-0.0155, 0.4437, -0.1766, 0.1957, 0.2475, 0.3773, -0.2710, 0.3289, -0.2077, -0.2534,
-0.0832, -0.1632, 0.0728, 0.2520, 0.4153, 0.1659, -0.4342, 0.0541, 0.1812, -0.2305,
0.4440, 0.0946, 0.0410, -0.4381, -0.3161, 0.3906, -0.3958, -0.4238, 0.1975, 0.3440,
0.1437, -0.0568, 0.1492, -0.4248, -0.3304, 0.2786, -0.1328, -0.3740, -0.3566, 0.3074,
0.0924, 0.2684, -0.1527, 0.1826, 0.2424, 0.2002, 0.3479, -0.1089, 0.3472, -0.3677,
-0.4231, -0.0798, -0.3709, 0.3924, 0.2774, -0.3690, -0.0233, 0.2845, 0.1969, 0.1618,
-0.3742, -0.3619, 0.2925, -0.1838, -0.1495, -0.3747, 0.0341, -0.4243, -0.0732, -0.3997,
0.2139, 0.2425, 0.4171, -0.3358, 0.3534, 0.0938, -0.0582, -0.2681, -0.4293, 0.1027,
0.4101, 0.2641, -0.4110, -0.1681, 0.3582, -0.2089, 0.0852, 0.0963, 0.3866, 0.1955,
-0.2174, 0.1996, -0.2252, 0.1748, 0.1833, -0.3155, 0.2567, -0.4387, 0.3402, 0.0599};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 6 * hidden_size}};
std::vector<float> bias_data{
-0.1582, -0.0826, 0.4008, 0.0118, 0.2511, 0.1900, -0.2838, 0.2549, -0.2484, 0.2363,
-0.4083, -0.0295, -0.1161, 0.1211, 0.2509, -0.1414, -0.2628, -0.2992, 0.1517, 0.1817,
-0.2783, 0.3183, -0.1629, -0.3108, -0.3418, 0.0411, 0.2203, 0.2187, -0.2990, -0.0416,
0.0209, -0.1024, 0.4443, -0.4420, -0.0330, -0.3591, -0.2990, 0.2167, 0.1395, 0.2317,
0.1318, 0.1909, -0.3615, 0.1953, -0.2582, -0.2217, 0.3723, 0.1458, 0.2630, -0.0377,
0.1754, 0.0800, -0.3964, -0.3247, 0.4219, -0.0900, 0.3553, 0.2614, -0.1298, -0.1124};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input{-0.8432,
-0.9887,
1.3041,
-2.6430,
-0.3306,
-0.8504,
-0.3933,
0.5151,
-0.2951,
0.0093,
-1.1948,
-0.1239,
0.0373,
1.3211,
0.7854,
-0.4838,
-1.0536,
-0.2529};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
std::vector<float> ih_data{-0.0468, 0.5691, -0.0882, 0.8340, 0.1483, -0.3902, -0.5348,
0.4178, 1.0175, 0.9212, -0.0468, 0.5691, -0.0882, 0.8340,
0.1483, -0.3902, -0.5348, 0.4178, 1.0175, 0.9212};
float clip = 0.0f;
// no activation function specified, so default is used.
{
......
......@@ -741,6 +741,16 @@ TEST_CASE(gru_test)
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_args)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// 3 arguments
{
......@@ -836,7 +846,16 @@ TEST_CASE(gru_test)
EXPECT(p == prog);
}
}
TEST_CASE(gru_test_actv_funcs)
{
std::size_t sl = 5; // sequence len
std::size_t bs = 3; // batch size
std::size_t hs = 20; // hidden size
std::size_t is = 10; // input size
std::size_t nd = 2; // num directions
float clip = 0.0f;
// bidirection, 0 actv function
{
nd = 2;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment