// onnx_rnn_test.cpp
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"

Shucai Xiao's avatar
Shucai Xiao committed
13
14
15
16
17
18
19
20
migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = true)
{
    auto prog = migraphx::parse_onnx(name);
    if(eliminate_deadcode)
        migraphx::run_passes(prog, {migraphx::dead_code_elimination{}});

    // remove the last identity instruction
    auto last_ins = std::prev(prog.end());
21
22
23
24
    if(last_ins->name() == "@return")
    {
        prog.remove_instruction(last_ins);
    }
Shucai Xiao's avatar
Shucai Xiao committed
25
26
27
28

    return prog;
}

29
TEST_CASE(rnn_test_bidirectional)
30
31
32
33
34
35
36
{
    std::size_t sl = 5;  // sequence len
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 2;  // num directions
    float clip     = 0.0f;
37
38
39
40
41
42
43
44
    migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 2 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};

    migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
45
    auto* mm = p.get_main_module();
46

Shucai Xiao's avatar
Shucai Xiao committed
47
48
49
50
51
52
    auto seq     = mm->add_parameter("seq", seq_shape);
    auto w       = mm->add_parameter("w", w_shape);
    auto r       = mm->add_parameter("r", r_shape);
    auto bias    = mm->add_parameter("bias", bias_shape);
    auto seq_len = mm->add_parameter("seq_len", sl_shape);
    auto ih      = mm->add_parameter("h0", ih_shape);
53
54

    auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
55
56
57
58
59
60
61
62
63
64
65
        mm->add_instruction(migraphx::op::rnn{hs,
                                              {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
                                              migraphx::op::rnn_direction::bidirectional,
                                              clip},
                            seq,
                            w,
                            r,
                            bias,
                            seq_len,
                            ih);
    mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
66
    auto prog = optimize_onnx("onnx_rnn_bi.onnx");
67
68
69

    EXPECT(p == prog);
}
70

71
TEST_CASE(rnn_test_one_direction)
72
73
74
75
76
77
78
79
80
81
82
83
84
{
    std::size_t sl = 5;  // sequence len
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 1;  // num directions
    float clip     = 0.0f;
    migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 2 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
85
86
87
88

    // forward
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
89
90
91
92
93
94
95
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
96
97

        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
98
99
100
101
102
103
104
105
106
107
108
            mm->add_instruction(migraphx::op::rnn{hs,
                                                  {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
                                                  migraphx::op::rnn_direction::forward,
                                                  clip},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
109
        auto prog = optimize_onnx("onnx_rnn_forward.onnx");
110
111
112
113
114
115
116

        EXPECT(p == prog);
    }

    // reverse
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
117
118
119
120
121
122
123
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
124
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
125
126
127
128
129
130
131
132
133
134
135
            mm->add_instruction(migraphx::op::rnn{hs,
                                                  {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
                                                  migraphx::op::rnn_direction::reverse,
                                                  clip},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
136
        auto prog = optimize_onnx("onnx_rnn_reverse.onnx");
137
138
139
140
141
142
143

        EXPECT(p == prog);
    }

    // 3 argumments
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
144
145
146
147
148
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        auto und = mm->add_instruction(migraphx::op::undefined{});
149
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
150
151
152
153
154
155
156
157
158
159
160
            mm->add_instruction(migraphx::op::rnn{hs,
                                                  {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
                                                  migraphx::op::rnn_direction::reverse,
                                                  clip},
                                seq,
                                w,
                                r,
                                und,
                                und,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
161
        auto prog = optimize_onnx("onnx_rnn_3args.onnx");
162
163
164
165
166
167
168

        EXPECT(p == prog);
    }

    // 5 argumments
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
169
        auto* mm = p.get_main_module();
170

Shucai Xiao's avatar
Shucai Xiao committed
171
172
173
174
175
176
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});
177
178

        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
179
180
181
182
183
184
185
186
187
188
189
            mm->add_instruction(migraphx::op::rnn{hs,
                                                  {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
                                                  migraphx::op::rnn_direction::forward,
                                                  clip},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
190
        auto prog = optimize_onnx("onnx_rnn_5args.onnx");
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207

        EXPECT(p == prog);
    }
}

// ONNX GRU nodes with all six inputs supplied, in forward, reverse and bidirectional
// form.
TEST_CASE(gru_test)
{
    const std::size_t sl = 5;  // sequence len
    const std::size_t bs = 3;  // batch size
    const std::size_t hs = 20; // hidden size
    const std::size_t is = 10; // input size
    const float clip     = 0.0f;

    // Builds the expected program for a GRU node with `nd` directions:
    // - the six parameters (seq, w, r, bias, seq_len, h0), in that order;
    // - `op` applied to them;
    // - rnn_last_hs_output.
    // W and R stack three gates along dim 1, hence the 3 * hs / 6 * hs sizes.
    auto make_expected = [&](std::size_t nd, const migraphx::op::gru& op) {
        migraphx::program p;
        auto* mm = p.get_main_module();

        auto seq =
            mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
        auto w =
            mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
        auto r =
            mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
        auto bias =
            mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
        auto seq_len =
            mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
        auto ih =
            mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});

        auto out_hs = mm->add_instruction(op, seq, w, r, bias, seq_len, ih);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        return p;
    };

    // forward; the trailing `1` is the gru op's fifth field (presumably linear_before_reset)
    {
        auto p    = make_expected(1,
                               migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
                                                 migraphx::op::rnn_direction::forward,
                                                 clip,
                                                 1});
        auto prog = optimize_onnx("onnx_gru_forward.onnx");
        EXPECT(p == prog);
    }

    // reverse
    {
        auto p    = make_expected(1,
                               migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
                                                 migraphx::op::rnn_direction::reverse,
                                                 clip});
        auto prog = optimize_onnx("onnx_gru_reverse.onnx");
        EXPECT(p == prog);
    }

    // bidirectional
    {
        auto p    = make_expected(2,
                               migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{},
                                                  migraphx::op::sigmoid{},
                                                  migraphx::op::relu{},
                                                  migraphx::op::tanh{}},
                                                 migraphx::op::rnn_direction::bidirectional,
                                                 clip});
        auto prog = optimize_onnx("onnx_gru_bi.onnx");
        EXPECT(p == prog);
    }
}

// ONNX GRU nodes that omit trailing optional inputs (3, 4 and 5 of the 6 supplied).
TEST_CASE(gru_test_args)
{
    const std::size_t sl = 5;  // sequence len
    const std::size_t bs = 3;  // batch size
    const std::size_t hs = 20; // hidden size
    const std::size_t is = 10; // input size
    const float clip     = 0.0f;

    // Builds the expected program for a GRU node with `nd` directions that supplies the
    // first `num_inputs` of its six inputs (seq, w, r, bias, seq_len, h0).
    // - Each supplied input becomes a parameter, in that order.
    // - When fewer than six are supplied, one `undefined` instruction is added after the
    //   parameters and used for every omitted input.
    auto make_expected = [&](std::size_t nd, const migraphx::op::gru& op, std::size_t num_inputs) {
        migraphx::program p;
        auto* mm = p.get_main_module();

        std::vector<migraphx::instruction_ref> args;
        args.push_back(
            mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}}));
        args.push_back(mm->add_parameter(
            "w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}}));
        args.push_back(mm->add_parameter(
            "r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}}));
        if(num_inputs > 3)
            args.push_back(mm->add_parameter(
                "bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}}));
        if(num_inputs > 4)
            args.push_back(
                mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}}));
        if(num_inputs > 5)
            args.push_back(mm->add_parameter(
                "h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}}));
        if(args.size() < 6)
        {
            auto und = mm->add_instruction(migraphx::op::undefined{});
            args.resize(6, und);
        }

        auto out_hs = mm->add_instruction(op, args);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        return p;
    };

    // 3 arguments (seq, w, r), forward
    {
        auto p    = make_expected(1,
                               migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{}, migraphx::op::sigmoid{}},
                                                 migraphx::op::rnn_direction::forward,
                                                 clip},
                               3);
        auto prog = optimize_onnx("onnx_gru_3arg.onnx");
        EXPECT(p == prog);
    }

    // 4 arguments (seq, w, r, bias), reverse
    {
        auto p    = make_expected(1,
                               migraphx::op::gru{hs,
                                                 {migraphx::op::relu{}, migraphx::op::tanh{}},
                                                 migraphx::op::rnn_direction::reverse,
                                                 clip},
                               4);
        auto prog = optimize_onnx("onnx_gru_4arg.onnx");
        EXPECT(p == prog);
    }

    // 5 arguments (no initial hidden state), bidirectional
    {
        auto p    = make_expected(2,
                               migraphx::op::gru{hs,
                                                 {migraphx::op::tanh{},
                                                  migraphx::op::sigmoid{},
                                                  migraphx::op::relu{},
                                                  migraphx::op::tanh{}},
                                                 migraphx::op::rnn_direction::bidirectional,
                                                 clip},
                               5);
        auto prog = optimize_onnx("onnx_gru_5arg.onnx");
        EXPECT(p == prog);
    }
}

TEST_CASE(gru_test_actv_funcs)
{
    std::size_t sl = 5;  // sequence len
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 2;  // num directions
    float clip     = 0.0f;
    // bidirection, 0 actv function
    {
        nd = 2;
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
440
        auto* mm = p.get_main_module();
441
442

        auto seq =
Shucai Xiao's avatar
Shucai Xiao committed
443
            mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
444
        auto w =
Shucai Xiao's avatar
Shucai Xiao committed
445
            mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
446
        auto r =
Shucai Xiao's avatar
Shucai Xiao committed
447
            mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
448
        auto bias =
Shucai Xiao's avatar
Shucai Xiao committed
449
            mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
450
        auto seq_len =
Shucai Xiao's avatar
Shucai Xiao committed
451
452
453
            mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
        auto ih =
            mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
454

Shucai Xiao's avatar
Shucai Xiao committed
455
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
456
457
458
459
460
461
462
463
464
465
466
467
468
469
            mm->add_instruction(migraphx::op::gru{hs,
                                                  {migraphx::op::sigmoid{},
                                                   migraphx::op::tanh{},
                                                   migraphx::op::sigmoid{},
                                                   migraphx::op::tanh{}},
                                                  migraphx::op::rnn_direction::bidirectional,
                                                  clip},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
470
        auto prog = optimize_onnx("onnx_gru_bi_0.onnx");
471
472
473
474
475
476
477
478

        EXPECT(p == prog);
    }

    // bidirection, 1 actv function
    {
        nd = 2;
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
479
        auto* mm = p.get_main_module();
480
481

        auto seq =
Shucai Xiao's avatar
Shucai Xiao committed
482
            mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
483
        auto w =
Shucai Xiao's avatar
Shucai Xiao committed
484
            mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
485
        auto r =
Shucai Xiao's avatar
Shucai Xiao committed
486
            mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
487
        auto bias =
Shucai Xiao's avatar
Shucai Xiao committed
488
            mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
489
        auto seq_len =
Shucai Xiao's avatar
Shucai Xiao committed
490
491
492
            mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
        auto ih =
            mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
493

Shucai Xiao's avatar
Shucai Xiao committed
494
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
495
496
497
498
499
500
501
502
503
504
505
506
507
508
            mm->add_instruction(migraphx::op::gru{hs,
                                                  {migraphx::op::sigmoid{},
                                                   migraphx::op::sigmoid{},
                                                   migraphx::op::sigmoid{},
                                                   migraphx::op::sigmoid{}},
                                                  migraphx::op::rnn_direction::bidirectional,
                                                  clip},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
509
        auto prog = optimize_onnx("onnx_gru_bi_1.onnx");
510
511
512
513
514
515
516
517

        EXPECT(p == prog);
    }

    // bidirection, 2 actv functions
    {
        nd = 2;
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
518
        auto* mm = p.get_main_module();
519
520

        auto seq =
Shucai Xiao's avatar
Shucai Xiao committed
521
            mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
522
        auto w =
Shucai Xiao's avatar
Shucai Xiao committed
523
            mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
524
        auto r =
Shucai Xiao's avatar
Shucai Xiao committed
525
            mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
526
        auto bias =
Shucai Xiao's avatar
Shucai Xiao committed
527
            mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
528
        auto seq_len =
Shucai Xiao's avatar
Shucai Xiao committed
529
530
531
            mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
        auto ih =
            mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
532
533

        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
534
535
536
537
538
539
540
541
542
543
544
545
546
547
            mm->add_instruction(migraphx::op::gru{hs,
                                                  {migraphx::op::tanh{},
                                                   migraphx::op::sigmoid{},
                                                   migraphx::op::tanh{},
                                                   migraphx::op::sigmoid{}},
                                                  migraphx::op::rnn_direction::bidirectional,
                                                  clip},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
548
        auto prog = optimize_onnx("onnx_gru_bi_2.onnx");
549
550
551
552
553
554
555
556

        EXPECT(p == prog);
    }

    // bidirection, 3 actv functions
    {
        nd = 2;
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
557
        auto* mm = p.get_main_module();
558
559

        auto seq =
Shucai Xiao's avatar
Shucai Xiao committed
560
            mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
561
        auto w =
Shucai Xiao's avatar
Shucai Xiao committed
562
            mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
563
        auto r =
Shucai Xiao's avatar
Shucai Xiao committed
564
            mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
565
        auto bias =
Shucai Xiao's avatar
Shucai Xiao committed
566
            mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
567
        auto seq_len =
Shucai Xiao's avatar
Shucai Xiao committed
568
569
570
            mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
        auto ih =
            mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
571

Shucai Xiao's avatar
Shucai Xiao committed
572
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
573
574
575
576
577
578
579
580
581
582
583
584
585
586
            mm->add_instruction(migraphx::op::gru{hs,
                                                  {migraphx::op::tanh{},
                                                   migraphx::op::sigmoid{},
                                                   migraphx::op::tanh{},
                                                   migraphx::op::tanh{}},
                                                  migraphx::op::rnn_direction::bidirectional,
                                                  clip},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
587
        auto prog = optimize_onnx("onnx_gru_bi_3.onnx");
588
589
590
591
592
593
594
595

        EXPECT(p == prog);
    }

    // forward, 0 actv function
    {
        nd = 1;
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
596
        auto* mm = p.get_main_module();
597
598

        auto seq =
Shucai Xiao's avatar
Shucai Xiao committed
599
            mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
600
        auto w =
Shucai Xiao's avatar
Shucai Xiao committed
601
            mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
602
        auto r =
Shucai Xiao's avatar
Shucai Xiao committed
603
            mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
604
        auto bias =
Shucai Xiao's avatar
Shucai Xiao committed
605
            mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
606
        auto seq_len =
Shucai Xiao's avatar
Shucai Xiao committed
607
608
609
            mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
        auto ih =
            mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
610
611

        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
612
613
614
615
616
617
618
619
620
621
622
            mm->add_instruction(migraphx::op::gru{hs,
                                                  {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
                                                  migraphx::op::rnn_direction::forward,
                                                  clip},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
623
        auto prog = optimize_onnx("onnx_gru_forward_0.onnx");
624
625
626
627
628
629
630
631

        EXPECT(p == prog);
    }

    // reverse, 1 actv function
    {
        nd = 1;
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
632
        auto* mm = p.get_main_module();
633
634

        auto seq =
Shucai Xiao's avatar
Shucai Xiao committed
635
            mm->add_parameter("seq", migraphx::shape{migraphx::shape::float_type, {sl, bs, is}});
636
        auto w =
Shucai Xiao's avatar
Shucai Xiao committed
637
            mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, is}});
638
        auto r =
Shucai Xiao's avatar
Shucai Xiao committed
639
            mm->add_parameter("r", migraphx::shape{migraphx::shape::float_type, {nd, 3 * hs, hs}});
640
        auto bias =
Shucai Xiao's avatar
Shucai Xiao committed
641
            mm->add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {nd, 6 * hs}});
642
        auto seq_len =
Shucai Xiao's avatar
Shucai Xiao committed
643
644
645
            mm->add_parameter("seq_len", migraphx::shape{migraphx::shape::int32_type, {bs}});
        auto ih =
            mm->add_parameter("h0", migraphx::shape{migraphx::shape::float_type, {nd, bs, hs}});
646

Shucai Xiao's avatar
Shucai Xiao committed
647
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
648
649
650
651
652
653
654
655
656
657
658
            mm->add_instruction(migraphx::op::gru{hs,
                                                  {migraphx::op::relu{}, migraphx::op::relu{}},
                                                  migraphx::op::rnn_direction::reverse,
                                                  clip},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
659
        auto prog = optimize_onnx("onnx_gru_reverse_1.onnx");
660
661
662
663
664

        EXPECT(p == prog);
    }
}

665
666
TEST_CASE(lstm_forward)
{
Shucai Xiao's avatar
Shucai Xiao committed
667
668
669
670
671
672
    std::size_t sl   = 5;  // sequence len
    std::size_t bs   = 3;  // batch size
    std::size_t hs   = 20; // hidden size
    std::size_t is   = 10; // input size
    std::size_t nd   = 1;  // num directions
    float clip       = 0.0f;
673
674
    int input_forget = 1;
    migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
Shucai Xiao's avatar
Shucai Xiao committed
675
676
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
677
678
679
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
Shucai Xiao's avatar
Shucai Xiao committed
680
    migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
681
682
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
683
684
685
686
687
688
689
690
691
692
693
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto pph     = mm->add_parameter("pph", pph_shape);

        auto out_hs = mm->add_instruction(
Shucai Xiao's avatar
Shucai Xiao committed
694
695
696
697
698
699
700
701
702
703
704
705
706
707
            migraphx::op::lstm{
                hs,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
Shucai Xiao's avatar
Shucai Xiao committed
708
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
709
        auto prog = optimize_onnx("onnx_lstm_forward.onnx");
710
711
712

        EXPECT(p == prog);
    }
713
714
715
716

    // 3 args
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
717
718
719
720
721
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        auto und = mm->add_instruction(migraphx::op::undefined{});
722

Shucai Xiao's avatar
Shucai Xiao committed
723
        auto out_hs = mm->add_instruction(
724
725
726
727
728
729
730
731
732
733
734
735
736
737
            migraphx::op::lstm{
                hs,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            und,
            und,
            und,
            und,
            und);
Shucai Xiao's avatar
Shucai Xiao committed
738
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
739
740
741
742
743
744
745
746
        auto prog = optimize_onnx("onnx_lstm_f3args.onnx");

        EXPECT(p == prog);
    }

    // 3 args, hs output
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
747
748
749
750
751
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        auto und = mm->add_instruction(migraphx::op::undefined{});
Shucai Xiao's avatar
Shucai Xiao committed
752

Shucai Xiao's avatar
Shucai Xiao committed
753
        mm->add_instruction(
Shucai Xiao's avatar
Shucai Xiao committed
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
            migraphx::op::lstm{
                hs,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            und,
            und,
            und,
            und,
            und);
        auto prog = optimize_onnx("onnx_lstm_hs.onnx");

        EXPECT(p == prog);
    }

    // 3 args, last output
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
776
777
778
779
780
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        auto und = mm->add_instruction(migraphx::op::undefined{});
Shucai Xiao's avatar
Shucai Xiao committed
781

Shucai Xiao's avatar
Shucai Xiao committed
782
        auto out_hs = mm->add_instruction(
Shucai Xiao's avatar
Shucai Xiao committed
783
784
785
786
787
788
789
790
791
792
793
794
795
796
            migraphx::op::lstm{
                hs,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            und,
            und,
            und,
            und,
            und);
Shucai Xiao's avatar
Shucai Xiao committed
797
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
798
799
800
801
802
803
804
805
        auto prog = optimize_onnx("onnx_lstm_last.onnx");

        EXPECT(p == prog);
    }

    // 3 args, cell output
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
806
807
808
809
810
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        auto und = mm->add_instruction(migraphx::op::undefined{});
Shucai Xiao's avatar
Shucai Xiao committed
811

Shucai Xiao's avatar
Shucai Xiao committed
812
        auto out_hs = mm->add_instruction(
Shucai Xiao's avatar
Shucai Xiao committed
813
814
815
816
817
818
819
820
821
822
823
824
825
826
            migraphx::op::lstm{
                hs,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            und,
            und,
            und,
            und,
            und);
Shucai Xiao's avatar
Shucai Xiao committed
827
        mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
828
        auto prog = optimize_onnx("onnx_lstm_cell.onnx");
829
830
831
832
833
834
835

        EXPECT(p == prog);
    }

    // 4 args
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
836
837
838
839
840
841
842
843
        auto* mm  = p.get_main_module();
        auto seq  = mm->add_parameter("seq", seq_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", bias_shape);
        auto und  = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs = mm->add_instruction(
844
845
846
847
848
849
850
851
852
853
854
855
856
857
            migraphx::op::lstm{
                hs,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            bias,
            und,
            und,
            und,
            und);
Shucai Xiao's avatar
Shucai Xiao committed
858
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
859
        auto prog = optimize_onnx("onnx_lstm_f4args.onnx");
860
861
862
863
864
865
866

        EXPECT(p == prog);
    }

    // 5 args
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
867
868
869
870
871
872
873
874
875
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs = mm->add_instruction(
876
877
878
879
880
881
882
883
884
885
886
887
888
889
            migraphx::op::lstm{
                hs,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            bias,
            seq_len,
            und,
            und,
            und);
Shucai Xiao's avatar
Shucai Xiao committed
890
891
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
892
        auto prog = optimize_onnx("onnx_lstm_f5args.onnx");
893
894
895
896
897
898
899

        EXPECT(p == prog);
    }

    // 6 args
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
900
901
902
903
904
905
906
907
908
909
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs = mm->add_instruction(
910
911
912
913
914
915
916
917
918
919
920
921
922
923
            migraphx::op::lstm{
                hs,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            und,
            und);
Shucai Xiao's avatar
Shucai Xiao committed
924
925
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
926
        auto prog = optimize_onnx("onnx_lstm_f6args.onnx");
927
928
929
930
931
932
933

        EXPECT(p == prog);
    }

    // 7 args
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
934
935
936
937
938
939
940
941
942
943
944
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs = mm->add_instruction(
945
946
947
948
949
950
951
952
953
954
955
956
957
958
            migraphx::op::lstm{
                hs,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            und);
Shucai Xiao's avatar
Shucai Xiao committed
959
960
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
961
        auto prog = optimize_onnx("onnx_lstm_f7args.onnx");
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984

        EXPECT(p == prog);
    }
}

// Forward LSTM import with 0, 1 and 2 activation functions given in the model.
// Each expected program pins the complete three-function activation list the
// parsed lstm operator must carry.
TEST_CASE(lstm_forward_actv_func)
{
    std::size_t sl   = 5;  // sequence len
    std::size_t bs   = 3;  // batch size
    std::size_t hs   = 20; // hidden size
    std::size_t is   = 10; // input size
    std::size_t nd   = 1;  // num directions
    float clip       = 0.0f;
    int input_forget = 1;
    migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};

    // no activation function specified: expects sigmoid, tanh, tanh
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        auto und = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs = mm->add_instruction(
            migraphx::op::lstm{
                hs,
                {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            und,
            und,
            und,
            und,
            und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_f0af.onnx");

        EXPECT(p == prog);
    }

    // 1 activation function specified: expects sigmoid in all three slots
    {
        migraphx::program p;
        auto* mm  = p.get_main_module();
        auto seq  = mm->add_parameter("seq", seq_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", bias_shape);
        auto und  = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs = mm->add_instruction(
            migraphx::op::lstm{
                hs,
                {migraphx::op::sigmoid{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            bias,
            und,
            und,
            und,
            und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_f1af.onnx");

        EXPECT(p == prog);
    }

    // 2 activation functions specified: expects tanh, sigmoid, sigmoid;
    // both last hidden and last cell states are extracted
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs = mm->add_instruction(
            migraphx::op::lstm{
                hs,
                {migraphx::op::tanh{}, migraphx::op::sigmoid{}, migraphx::op::sigmoid{}},
                migraphx::op::rnn_direction::forward,
                clip,
                input_forget},
            seq,
            w,
            r,
            bias,
            seq_len,
            und,
            und,
            und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_f2af.onnx");

        EXPECT(p == prog);
    }
}

// Reverse-direction LSTM import tests. Each sub-case builds the program the
// ONNX parser is expected to produce for one .onnx model and compares it with
// the parsed result. Omitted optional inputs are represented by a single
// `undefined` instruction passed in each absent slot.
TEST_CASE(lstm_reverse)
{
    std::size_t sl   = 5;  // sequence len
    std::size_t bs   = 3;  // batch size
    std::size_t hs   = 20; // hidden size
    std::size_t is   = 10; // input size
    std::size_t nd   = 1;  // num directions
    float clip       = 0.0f;
    int input_forget = 1;
    migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};

    // Every sub-case below expects this same operator; only the supplied
    // inputs and the output-extraction instructions differ.
    const migraphx::op::lstm lstm_op{
        hs,
        {migraphx::op::sigmoid{}, migraphx::op::tanh{}, migraphx::op::tanh{}},
        migraphx::op::rnn_direction::reverse,
        clip,
        input_forget};

    // all 8 inputs supplied; last hidden state extracted
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto pph     = mm->add_parameter("pph", pph_shape);

        auto out_hs = mm->add_instruction(lstm_op, seq, w, r, bias, seq_len, ih, ic, pph);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_reverse.onnx");

        EXPECT(p == prog);
    }

    // 5 args; last hidden and last cell states extracted
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs = mm->add_instruction(lstm_op, seq, w, r, bias, seq_len, und, und, und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        mm->add_instruction(migraphx::op::rnn_last_cell_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_r5args.onnx");

        EXPECT(p == prog);
    }

    // no activation function specified: expects sigmoid, tanh, tanh
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        auto und = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs = mm->add_instruction(lstm_op, seq, w, r, und, und, und, und, und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_r0af.onnx");

        EXPECT(p == prog);
    }
}

TEST_CASE(lstm_bidirectional)
{
Shucai Xiao's avatar
Shucai Xiao committed
1193
1194
1195
1196
1197
1198
    std::size_t sl   = 5;  // sequence len
    std::size_t bs   = 3;  // batch size
    std::size_t hs   = 20; // hidden size
    std::size_t is   = 10; // input size
    std::size_t nd   = 2;  // num directions
    float clip       = 0.0f;
1199
1200
    int input_forget = 1;
    migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
Shucai Xiao's avatar
Shucai Xiao committed
1201
1202
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
1203
1204
1205
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};
Shucai Xiao's avatar
Shucai Xiao committed
1206
    migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};
1207
1208
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
1209
1210
1211
1212
1213
1214
1215
1216
1217
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto pph     = mm->add_parameter("pph", pph_shape);
1218

Shucai Xiao's avatar
Shucai Xiao committed
1219
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih,
                                ic,
                                pph);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
1239
        auto prog = optimize_onnx("onnx_lstm_bi.onnx");
1240
1241
1242

        EXPECT(p == prog);
    }
1243
1244
1245
1246

    // 3 args
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
1247
1248
1249
1250
1251
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        auto und = mm->add_instruction(migraphx::op::undefined{});
1252

Shucai Xiao's avatar
Shucai Xiao committed
1253
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                und,
                                und,
                                und,
                                und,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
1273
        auto prog = optimize_onnx("onnx_lstm_bi3args.onnx");
1274
1275
1276
1277
1278
1279
1280

        EXPECT(p == prog);
    }

    // 4 args
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
1281
1282
1283
1284
1285
1286
        auto* mm  = p.get_main_module();
        auto seq  = mm->add_parameter("seq", seq_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", bias_shape);
        auto und  = mm->add_instruction(migraphx::op::undefined{});
1287

Shucai Xiao's avatar
Shucai Xiao committed
1288
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                bias,
                                und,
                                und,
                                und,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
1308
        auto prog = optimize_onnx("onnx_lstm_bi4args.onnx");
1309
1310
1311
1312
1313
1314
1315

        EXPECT(p == prog);
    }

    // 5 args
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
1316
1317
1318
1319
1320
1321
1322
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});
1323

Shucai Xiao's avatar
Shucai Xiao committed
1324
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                und,
                                und,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
1344
        auto prog = optimize_onnx("onnx_lstm_bi5args.onnx");
1345
1346
1347
1348
1349
1350
1351

        EXPECT(p == prog);
    }

    // 6 args
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
1352
1353
1354
1355
1356
1357
1358
1359
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});
1360

Shucai Xiao's avatar
Shucai Xiao committed
1361
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih,
                                und,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
1381
        auto prog = optimize_onnx("onnx_lstm_bi6args.onnx");
1382
1383
1384
1385
1386
1387
1388

        EXPECT(p == prog);
    }

    // 7 args
    {
        migraphx::program p;
Shucai Xiao's avatar
Shucai Xiao committed
1389
1390
1391
1392
1393
1394
1395
1396
1397
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});
1398

Shucai Xiao's avatar
Shucai Xiao committed
1399
        auto out_hs =
Shucai Xiao's avatar
Shucai Xiao committed
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih,
                                ic,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
Shucai Xiao's avatar
Shucai Xiao committed
1419
        auto prog = optimize_onnx("onnx_lstm_bi7args.onnx");
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443

        EXPECT(p == prog);
    }
}

// Parse bidirectional LSTM ONNX models that specify 0, 1, 2, 4, 5 and 6
// activation functions and compare each against a hand-built expected program.
// Every expected program spells out all six activation functions passed to the
// bidirectional lstm op; the lists below encode how the activation functions
// given in each ONNX model are expected to be expanded. Optional LSTM inputs
// that a model does not provide are filled with an `undefined` instruction.
TEST_CASE(lstm_bi_actv_funcs)
{
    const std::size_t sl   = 5;  // sequence len
    const std::size_t bs   = 3;  // batch size
    const std::size_t hs   = 20; // hidden size
    const std::size_t is   = 10; // input size
    const std::size_t nd   = 2;  // num directions
    const float clip       = 0.0f;
    const int input_forget = 1;
    migraphx::shape seq_shape{migraphx::shape::float_type, {sl, bs, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {nd, bs, hs}};

    // 0 activation functions: only seq, w and r are provided; the expected
    // program uses sigmoid, tanh, tanh for each direction.
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        auto und = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs =
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                und,
                                und,
                                und,
                                und,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_bi0af.onnx");

        EXPECT(p == prog);
    }

    // 1 activation function: seq, w, r and bias are provided; the expected
    // program uses sigmoid in all six slots.
    {
        migraphx::program p;
        auto* mm  = p.get_main_module();
        auto seq  = mm->add_parameter("seq", seq_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", bias_shape);
        auto und  = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs =
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::sigmoid{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                bias,
                                und,
                                und,
                                und,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_bi1af.onnx");

        EXPECT(p == prog);
    }

    // 2 activation functions: seq, w, r, bias and seq_len are provided; the
    // expected program uses sigmoid, tanh, tanh for each direction.
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs =
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                und,
                                und,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_bi2af.onnx");

        EXPECT(p == prog);
    }

    // 4 activation functions: seq, w, r, bias, seq_len and h0 are provided;
    // the expected program uses sigmoid followed by five tanh.
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs =
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih,
                                und,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_bi4af.onnx");

        EXPECT(p == prog);
    }

    // 5 activation functions: seq, w, r, bias, seq_len, h0 and c0 are
    // provided; the expected program repeats the fifth function (sigmoid) in
    // the sixth slot.
    {
        migraphx::program p;
        auto* mm     = p.get_main_module();
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto und     = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs =
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::sigmoid{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                bias,
                                seq_len,
                                ih,
                                ic,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_bi5af.onnx");

        EXPECT(p == prog);
    }

    // 6 activation functions: only seq, w and r are provided; all six
    // functions are expected exactly as listed.
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_parameter("seq", seq_shape);
        auto w   = mm->add_parameter("w", w_shape);
        auto r   = mm->add_parameter("r", r_shape);
        auto und = mm->add_instruction(migraphx::op::undefined{});

        auto out_hs =
            mm->add_instruction(migraphx::op::lstm{hs,
                                                   {migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::tanh{},
                                                    migraphx::op::sigmoid{},
                                                    migraphx::op::tanh{}},
                                                   migraphx::op::rnn_direction::bidirectional,
                                                   clip,
                                                   input_forget},
                                seq,
                                w,
                                r,
                                und,
                                und,
                                und,
                                und,
                                und);
        mm->add_instruction(migraphx::op::rnn_last_hs_output{}, out_hs);
        auto prog = optimize_onnx("onnx_lstm_bi6af.onnx");

        EXPECT(p == prog);
    }
}

1656
int main(int argc, const char* argv[]) { test::run(argc, argv); }