Commit 717e6bc2 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add gpu test for the lstm operator.

parent 48c72797
...@@ -52,7 +52,6 @@ struct rewrite_rnn ...@@ -52,7 +52,6 @@ struct rewrite_rnn
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int input_forget,
const operation& actv_func1, const operation& actv_func1,
const operation& actv_func2, const operation& actv_func2,
const operation& actv_func3) const; const operation& actv_func3) const;
......
...@@ -751,7 +751,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -751,7 +751,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
prog, prog,
ins, ins,
{args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward}, {args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward},
lstm_op.input_forget,
actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(2)); actv_funcs.at(2));
...@@ -761,7 +760,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -761,7 +760,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
prog, prog,
ins, ins,
{args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse}, {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse},
lstm_op.input_forget,
actv_funcs.at(3), actv_funcs.at(3),
actv_funcs.at(4), actv_funcs.at(4),
actv_funcs.at(5)); actv_funcs.at(5));
...@@ -835,7 +833,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const ...@@ -835,7 +833,6 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
prog, prog,
ins, ins,
{args[0], w, r, bias, ih, ic, pph}, {args[0], w, r, bias, ih, ic, pph},
lstm_op.input_forget,
actv_funcs.at(0), actv_funcs.at(0),
actv_funcs.at(1), actv_funcs.at(1),
actv_funcs.at(2)); actv_funcs.at(2));
...@@ -892,7 +889,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward, ...@@ -892,7 +889,6 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
program& prog, program& prog,
instruction_ref ins, instruction_ref ins,
std::vector<instruction_ref> inputs, std::vector<instruction_ref> inputs,
int input_forget,
const operation& actv_func1, const operation& actv_func1,
const operation& actv_func2, const operation& actv_func2,
const operation& actv_func3) const const operation& actv_func3) const
......
...@@ -2341,6 +2341,90 @@ TEST_CASE(lstm_forward) ...@@ -2341,6 +2341,90 @@ TEST_CASE(lstm_forward)
0.119811}; 0.119811};
EXPECT(migraphx::verify_range(output_data, output_data_gold)); EXPECT(migraphx::verify_range(output_data, output_data_gold));
} }
}
TEST_CASE(lstm_forward_more)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260};
std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
-0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
-0.3025, 0.3637, -0.3181, -0.4655};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104,
-1.9004,
0.3337,
0.5741,
0.5671,
0.0458,
0.4514,
-0.8968,
-0.9201,
0.1962,
0.5771,
-0.5332};
std::vector<float> ic_data{0.9569,
-0.5981,
1.1312,
1.0945,
1.1055,
-0.1212,
-0.9097,
0.7831,
-1.6991,
-1.9498,
-1.2567,
-0.4114};
std::vector<float> pph_data{1.84369764,
0.68413646,
-0.44892886,
-1.50904413,
0.3860796,
-0.52186625,
1.08474445,
-1.80867321,
1.32594529,
0.4336262,
-0.83699064,
0.49162736};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// forward, 3 args // forward, 3 args
{ {
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment