rewrite_rnn.cpp 55.4 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
17
#include <migraphx/op/contiguous.hpp>
18
19
20
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
21
22
#include <migraphx/make_op.hpp>

Shucai Xiao's avatar
Shucai Xiao committed
23
24
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
25
#include <migraphx/ranges.hpp>
26
#include <migraphx/op/common.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
27
28
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
29
30
31
32

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

33
// Entry point of the pass: scan the module and lower every recurrent
// operator ("rnn", "gru", "lstm") into elementary instructions.
void rewrite_rnn::apply(module& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        const auto& op_name = ins->name();
        if(op_name == "rnn")
        {
            apply_vanilla_rnn(prog, ins);
        }
        else if(op_name == "gru")
        {
            apply_gru(prog, ins);
        }
        else if(op_name == "lstm")
        {
            apply_lstm(prog, ins);
        }
    }
}

52
// Lower one "rnn" instruction into elementary ops (slice, dot, add, concat,
// squeeze, ...) inserted before `ins`, then replace `ins` with the final
// concat so its outputs are preserved.
void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
{
    assert(ins->name() == "rnn");
    // could be 3 to 6 inputs, but the parse_rnn function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have num of arguments
    // when writing their module.
    auto args = ins->inputs();

    // args[0] = X (sequence), args[1] = W, args[2] = R; per the visible
    // indexing, hidden_size is taken from W's second dimension and
    // batch_size from X's second dimension.
    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    // Default (zero-filled) initial hidden state, used when args[5] is
    // absent or "undefined".
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0);

    auto actv_funcs         = vanilla_rnn_actv_funcs(ins);
    auto rnn_op             = any_cast<op::rnn>(ins->get_operator());
    op::rnn_direction dirct = rnn_op.direction;

    // process sequence length (args[4], optional); prog.end() acts as the
    // "not provided" sentinel throughout this file.
    instruction_ref seq_lens = prog.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

    instruction_ref last_output{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // input weight matrix: slice direction 0 (forward) and 1 (reverse)
        auto w_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // process bias (args[3], optional)
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // process intial hidden state, it could be the 6th argument
        // or the 5th one (if the sequence len argument is ignored)
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward =
            vanilla_rnn_cell(true,
                             prog,
                             ins,
                             {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                             actv_funcs.at(0));

        // For variable sequence lengths, the reverse direction must run on
        // the shifted sequence; note the shift is applied AFTER the forward
        // cell is built, so only the reverse cell sees the shifted input.
        if(variable_seq_len)
        {
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret_reverse =
            vanilla_rnn_cell(false,
                             prog,
                             ins,
                             {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                             actv_funcs.at(1));

        // ret[1] is the last timestep's hidden state (with seq/dir dims);
        // concatenate forward and reverse along the direction axis.
        auto concat_output = prog.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output =
            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);

        // The following logic is to ensure the last instruction rewritten from
        // rnn operator is a concat instruction
        // sequence len is 1
        if(ret_forward[0] == prog.end())
        {
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        }
        else
        {
            // Append the last step to each direction's hidden-state stack
            // (reverse order for the reverse direction), then join the two
            // directions on axis 1.
            ret_forward[0] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // input weight matrix
        auto w = args[1];

        // hidden state weight matrix
        auto r = args[2];

        // process bias and initial hidden state
        instruction_ref bias = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // process intial hidden state
        instruction_ref ih;
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        // Only the reverse direction needs the variable-seq-len shift.
        if(!is_forward and variable_seq_len)
        {
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret = vanilla_rnn_cell(
            is_forward, prog, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
        last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);

        // following logic is to ensure the last instruction is a
        // concat instruction
        // sequence len is 1
        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            // Reverse direction prepends the last step instead of appending.
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // in case of all sequences are of the same lengths and shorter than the
    // max sequence length, need to pad 0's at the end for output hidden states
    ins = pad_hidden_states(prog, args[0], seq_lens, ins);
    replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
}

Shucai Xiao's avatar
Shucai Xiao committed
226
// Unroll one direction of a vanilla RNN over the (static) sequence length,
// inserting elementary ops before `ins`.
//
// inputs (fixed order, size 6): {seq, W, R, bias, seq_lens, initial_h}.
// bias == prog.end() means "no bias".
//
// Returns {hidden_out, last_out}:
//   hidden_out — concat of all but the last timestep's hidden states
//                (prog.end() when seq_len == 1; the caller appends the
//                final step and emits the terminating concat),
//   last_out   — the last timestep's hidden state, unsqueezed to carry the
//                seq-length and num-directions axes.
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
                                                           module& prog,
                                                           instruction_ref ins,
                                                           std::vector<instruction_ref> inputs,
                                                           operation& actv_func) const
{
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);

    // squeeze and transpose w (drop the num_directions axis, transpose so
    // the per-step "dot" computes x_t * W^T)
    std::vector<int64_t> perm{1, 0};
    auto sw      = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);

    // squeeze and transpose r
    auto sr      = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);

    // initial hidden state
    auto sih      = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
    auto sih_lens = sih->get_shape().lens();

    // bias: fold Wb + Rb once, then broadcast to the hidden-state shape so
    // it can be added inside the loop without re-broadcasting per step.
    instruction_ref bb{};
    if(bias != prog.end())
    {
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
        auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias);
        auto rb = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
        auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
        bb       = prog.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"dims", sih_lens}}), wrb);
    }

    instruction_ref hidden_out = prog.end();
    instruction_ref last_out{};
    // Pre-seed last_out with the initial state so a zero-iteration loop
    // still returns something well-defined.
    last_out     = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
    long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
    for(long i = 0; i < seq_len; i++)
    {
        // Reverse direction walks the sequence back-to-front.
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        // slice yields a non-standard-layout view; make it contiguous before
        // squeezing/feeding to dot.
        auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
        auto xt_wi   = prog.insert_instruction(ins, make_op("dot"), xt, tran_sw);
        auto ht_ri   = prog.insert_instruction(ins, make_op("dot"), sih, tran_sr);
        if(bias != prog.end())
        {
            xt_wi = prog.insert_instruction(ins, make_op("add"), xt_wi, bb);
        }
        auto xt_ht = prog.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);

        // apply activation function; the result becomes the next step's
        // hidden state
        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;

        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);

        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_out = (seq_index == 0)
                                 ? last_out
                                 : prog.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out);
            }
            else
            {
                // Prepend so hidden_out stays in ascending sequence order.
                hidden_out = (seq_index == seq_len - 1)
                                 ? last_out
                                 : prog.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out);
            }
        }
    }

    return {hidden_out, last_out};
}

Shucai Xiao's avatar
Shucai Xiao committed
322
// Normalize the activation-function list of an "rnn" operator: a vanilla
// RNN needs one activation per direction, so bidirectional needs two and
// unidirectional needs one. When the user supplies none the default is
// tanh; a single function on a bidirectional rnn is used for both
// directions; otherwise the list is passed through unchanged.
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    const auto rnn_op  = any_cast<op::rnn>(ins->get_operator());
    const auto& funcs  = rnn_op.actv_funcs;
    const bool is_bidi = (rnn_op.direction == op::rnn_direction::bidirectional);

    if(funcs.empty())
    {
        // default is tanh
        return is_bidi ? std::vector<operation>{make_op("tanh"), make_op("tanh")}
                       : std::vector<operation>{make_op("tanh")};
    }

    if(is_bidi and funcs.size() == 1)
    {
        // reuse the single function for both directions
        return {funcs.at(0), funcs.at(0)};
    }

    return funcs;
}

359
// Lower one "gru" instruction into elementary ops inserted before `ins`,
// then replace `ins` with the final concat so its outputs are preserved.
// Mirrors apply_vanilla_rnn, with gru_cell doing the per-step math.
void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);
    // could be 3 to 6 inputs, but the parse_gru function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have num of arguments
    // when writing their module.
    auto args = ins->inputs();

    // args[0] = X (sequence), args[2] = R; hidden_size is R's last
    // dimension, batch_size is X's second dimension.
    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    // Default (zero-filled) initial hidden state when args[5] is missing.
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0.0);

    auto gru_op             = any_cast<op::gru>(ins->get_operator());
    op::rnn_direction dirct = gru_op.direction;

    // process sequence length (args[4], optional); prog.end() is the
    // "not provided" sentinel.
    instruction_ref seq_lens = prog.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

    instruction_ref last_output{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // w weight matrix: slice direction 0 (forward) and 1 (reverse)
        auto w_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // r weight matrix
        auto r_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // bias (args[3], optional)
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // intial hidden state (args[5], optional)
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        // Bidirectional uses actv_funcs {0,1} for the forward cell and
        // {2,3} for the reverse cell (gate and candidate activations).
        auto ret_forward =
            gru_cell(true,
                     prog,
                     ins,
                     {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                     gru_op.linear_before_reset,
                     actv_funcs.at(0),
                     actv_funcs.at(1));

        // Shift applied AFTER the forward cell: only the reverse cell sees
        // the variable-seq-len shifted sequence.
        if(variable_seq_len)
        {
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret_reverse =
            gru_cell(false,
                     prog,
                     ins,
                     {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                     gru_op.linear_before_reset,
                     actv_funcs.at(2),
                     actv_funcs.at(3));

        // Last step of each direction joined on the num_directions axis.
        auto concat_output = prog.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output =
            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);

        // The following logic is to ensure the last instruction rewritten
        // from gru operator is a concat
        if(ret_forward[0] == prog.end())
        {
            // sequence length is 1
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        }
        else
        {
            // Append (forward) / prepend (reverse) the last step, then join
            // the two directions on axis 1.
            ret_forward[0] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // weight matrix
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // intial hidden state
        instruction_ref ih{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        // Only the reverse direction needs the variable-seq-len shift.
        if(!is_forward and variable_seq_len)
        {
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret = gru_cell(is_forward,
                            prog,
                            ins,
                            {args[0], w, r, bias, seq_lens, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);

        // Ensure the replacement instruction is a concat (see bidirectional
        // branch); ret[0] == prog.end() means sequence length 1.
        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            // Reverse direction prepends the last step instead of appending.
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // in case of all sequences are of the same lengths and shorter than the
    // max sequence length, need to pad 0's at the end for output hidden states
    ins = pad_hidden_states(prog, args[0], seq_lens, ins);
    replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
}

// Unroll one direction of a GRU over the (static) sequence length,
// inserting elementary ops before `ins`.
//
// inputs (fixed order, size 6): {seq, W, R, bias, seq_lens, initial_h};
// bias == prog.end() means "no bias". `linear_before_reset` selects
// between the two candidate-hidden-state formulas (see the equation
// comments in the loop). actv_func1 is applied to the z/r gates,
// actv_func2 to the candidate hidden state.
//
// Returns {hidden_states, last_output}: all-but-last timestep stack
// (prog.end() when seq_len == 1) and the final hidden state, matching the
// contract of vanilla_rnn_cell.
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
                                                   module& prog,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
{
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);

    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
    migraphx::shape seq_shape = seq->get_shape();
    migraphx::shape r_shape   = r->get_shape();
    // hidden size taken from R's last dimension
    long hs                   = static_cast<long>(r_shape.lens()[2]);

    // all-ones literal of hidden-state shape, used for (1 - zt) below
    migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
    std::vector<float> data(ss.elements(), 1.0f);
    auto l1 = prog.add_literal(migraphx::literal{ss, data});

    // w matrix squeeze to 2-dim and do a transpose
    std::vector<int64_t> perm{1, 0};
    auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);

    // r slide to two part, zr and h: rows [0, 2*hs) feed the z/r gates,
    // rows [2*hs, 3*hs) feed the candidate hidden state
    auto sr  = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto rzr = prog.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
    auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rzr);

    auto rh = prog.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
    auto trh = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), rh);

    // initial states
    auto sih  = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
    size_t bs = ih->get_shape().lens()[1];

    // bias: pre-slice and pre-broadcast the three bias pieces once so the
    // loop body only does adds. Layout per the slices below: [0,3hs) = Wb
    // (all three gates), [3hs,5hs) = Rb for z/r, [5hs,6hs) = Rb for h.
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
    if(bias != prog.end())
    {
        auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
        bwb = prog.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(3 * hs)}}}),
            wb);

        auto rb_zr = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}),
            sbias);
        auto rb_h = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}),
            sbias);
        brb_zr = prog.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(2 * hs)}}}),
            rb_zr);
        brb_h = prog.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"dims", {bs, static_cast<size_t>(hs)}}}),
            rb_h);
    }

    long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
    for(long i = 0; i < seq_len; i++)
    {
        // Reverse direction walks the sequence back-to-front.
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        // make the slice contiguous before squeeze/dot
        auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);

        // x_t * W^T (all three gates at once) and H_{t-1} * R_{zr}^T
        auto xt_w    = prog.insert_instruction(ins, make_op("dot"), xt, tw);
        auto ih1_rzr = prog.insert_instruction(ins, make_op("dot"), sih, trzr);
        if(bias != prog.end())
        {
            xt_w    = prog.insert_instruction(ins, make_op("add"), xt_w, bwb);
            ih1_rzr = prog.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
        }

        // split the fused x_t*W^T result into the z, r, h gate inputs
        auto xw_z = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w);
        auto xw_r = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w);
        auto xw_h = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w);

        auto hr_z = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr);
        auto hr_r = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr);

        // update gate: zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + biases)
        auto xw_hr_z = prog.insert_instruction(ins, make_op("add"), xw_z, hr_z);
        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);

        // reset gate: rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + biases)
        auto xw_hr_r = prog.insert_instruction(ins, make_op("add"), xw_r, hr_r);
        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto rt_ht1 = prog.insert_instruction(ins, make_op("mul"), rt, sih);
            hr_h        = prog.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
            if(bias != prog.end())
            {
                hr_h = prog.insert_instruction(ins, make_op("add"), hr_h, brb_h);
            }
        }
        else
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            auto ht1_rh = prog.insert_instruction(ins, make_op("dot"), sih, trh);
            if(bias != prog.end())
            {
                ht1_rh = prog.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
            }
            hr_h = prog.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
        }

        // candidate hidden state
        auto xw_hr_h = prog.insert_instruction(ins, make_op("add"), xw_h, hr_h);
        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = prog.insert_instruction(ins, make_op("sub"), l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
        auto zt_ht1          = prog.insert_instruction(ins, make_op("mul"), zt, sih);
        sih         = prog.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
        // restore the seq-length and num-directions axes on the step output
        last_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);

        // accumulate all but the last step; the caller emits the final
        // concat so the replacement instruction is a concat
        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : prog.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output);
            }
            else
            {
                // prepend so hidden_states stays in ascending sequence order
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : prog.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}

// Normalize the activation-function list of a "gru" operator. A GRU needs
// two activations per direction (gate and candidate), so bidirectional
// needs four and unidirectional two. Missing entries are filled with the
// same replication scheme parse_gru uses; the defaults are
// sigmoid (gates) and tanh (candidate).
std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    const auto gru_op = any_cast<op::gru>(ins->get_operator());
    const auto& funcs = gru_op.actv_funcs;

    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        switch(funcs.size())
        {
        case 0:
            // defaults, duplicated for the reverse direction
            return {make_op("sigmoid"), make_op("tanh"), make_op("sigmoid"), make_op("tanh")};
        case 1:
            // one function supplied: use it everywhere
            return {funcs.at(0), funcs.at(0), funcs.at(0), funcs.at(0)};
        case 2:
            // forward pair supplied: reuse it for the reverse direction
            return {funcs.at(0), funcs.at(1), funcs.at(0), funcs.at(1)};
        case 3:
            // pad the missing fourth entry with the first
            return {funcs.at(0), funcs.at(1), funcs.at(2), funcs.at(0)};
        default: return funcs;
        }
    }

    switch(funcs.size())
    {
    case 0:
        // defaults
        return {make_op("sigmoid"), make_op("tanh")};
    case 1:
        // one function supplied: use it for both slots
        return {funcs.at(0), funcs.at(0)};
    default: return funcs;
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
746
// for lstm operators
747
void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
Shucai Xiao's avatar
Shucai Xiao committed
748
749
750
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();
Shucai Xiao's avatar
Shucai Xiao committed
751

Shucai Xiao's avatar
Shucai Xiao committed
752
    shape seq_shape         = args[0]->get_shape();
753
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
Shucai Xiao's avatar
Shucai Xiao committed
754
755
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
Shucai Xiao's avatar
Shucai Xiao committed
756
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
757
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
Shucai Xiao's avatar
Shucai Xiao committed
758
759
760

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
761
762
    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
Shucai Xiao's avatar
Shucai Xiao committed
763
    op::rnn_direction dirct = lstm_op.direction;
Shucai Xiao's avatar
Shucai Xiao committed
764

Shucai Xiao's avatar
Shucai Xiao committed
765
766
767
768
769
770
771
772
773
774
    // process sequence length
    instruction_ref seq_lens = prog.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

    instruction_ref last_hs_output{};
Shucai Xiao's avatar
Shucai Xiao committed
775
    instruction_ref last_cell_output{};
Shucai Xiao's avatar
Shucai Xiao committed
776
777
    instruction_ref hidden_state{};
    instruction_ref cell_outputs{};
Shucai Xiao's avatar
Shucai Xiao committed
778
    if(dirct == op::rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
779
780
781
    {
        // input weight matrix
        // input weight matrix
782
783
784
785
        auto w_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
Shucai Xiao's avatar
Shucai Xiao committed
786
787

        // hidden state weight matrix
788
789
790
791
        auto r_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
Shucai Xiao's avatar
Shucai Xiao committed
792
793
794
795

        // process bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
796
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
797
        {
798
799
800
801
            bias_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
Shucai Xiao's avatar
Shucai Xiao committed
802
803
804
805
806
807
808
        }

        // process intial hidden state, it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
809
810
811
812
            ih_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
Shucai Xiao's avatar
Shucai Xiao committed
813
814
815
816
817
818
819
820
821
822
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
823
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
824
        {
825
826
827
828
            ic_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
            ic_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]);
Shucai Xiao's avatar
Shucai Xiao committed
829
830
831
832
833
834
835
836
        }
        else
        {
            ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
837
838
        instruction_ref pph_forward = prog.end();
        instruction_ref pph_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
839
        if(args.size() == 8 && args[7]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
840
        {
841
842
843
844
            pph_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
            pph_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]);
Shucai Xiao's avatar
Shucai Xiao committed
845
        }
Shucai Xiao's avatar
Shucai Xiao committed
846

Shucai Xiao's avatar
Shucai Xiao committed
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
        auto ret_forward = lstm_cell(true,
                                     prog,
                                     ins,
                                     {args[0],
                                      w_forward,
                                      r_forward,
                                      bias_forward,
                                      seq_lens,
                                      ih_forward,
                                      ic_forward,
                                      pph_forward},
                                     actv_funcs.at(0),
                                     actv_funcs.at(1),
                                     actv_funcs.at(2));

        if(variable_seq_len)
        {
864
865
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
Shucai Xiao's avatar
Shucai Xiao committed
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
        }
        auto ret_reverse = lstm_cell(false,
                                     prog,
                                     ins,
                                     {args[0],
                                      w_reverse,
                                      r_reverse,
                                      bias_reverse,
                                      seq_lens,
                                      ih_reverse,
                                      ic_reverse,
                                      pph_reverse},
                                     actv_funcs.at(3),
                                     actv_funcs.at(4),
                                     actv_funcs.at(5));

882
883
884
885
886
887
888
889
        auto concat_hs_output = prog.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        auto concat_cell_output = prog.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
        last_hs_output =
            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
        last_cell_output =
            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);
Shucai Xiao's avatar
Shucai Xiao committed
890
891

        // the following logic is to ensure the last instruction is a concat
Shucai Xiao's avatar
Shucai Xiao committed
892
        if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
893
        {
Shucai Xiao's avatar
Shucai Xiao committed
894
            cell_outputs = concat_cell_output;
Shucai Xiao's avatar
Shucai Xiao committed
895
896
897
        }
        else
        {
898
899
900
901
            ret_forward[1] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[1] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
Shucai Xiao's avatar
Shucai Xiao committed
902

903
904
905
906
907
908
            ret_forward[3] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]);
            ret_reverse[3] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]);
            cell_outputs = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
Shucai Xiao's avatar
Shucai Xiao committed
909
        }
Shucai Xiao's avatar
Shucai Xiao committed
910

911
912
        hidden_state = prog.replace_instruction(
            ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]});
Shucai Xiao's avatar
Shucai Xiao committed
913
914
915
    }
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
916
        bool is_forward = (dirct == op::rnn_direction::forward);
Shucai Xiao's avatar
Shucai Xiao committed
917
918
919
920
921
922
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
923
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
924
925
926
927
928
929
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
Shucai Xiao's avatar
Shucai Xiao committed
930
        if(args.size() >= 6 && args[5]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
931
932
933
934
935
936
937
938
939
940
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
Shucai Xiao's avatar
Shucai Xiao committed
941
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
942
943
944
945
946
947
948
949
950
        {
            ic = args[6];
        }
        else
        {
            ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
951
        instruction_ref pph = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
952
953
954
955
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }
Shucai Xiao's avatar
Shucai Xiao committed
956

Shucai Xiao's avatar
Shucai Xiao committed
957
958
        if(!is_forward and variable_seq_len)
        {
959
960
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
Shucai Xiao's avatar
Shucai Xiao committed
961
        }
Shucai Xiao's avatar
Shucai Xiao committed
962
        auto ret = lstm_cell(is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
963
964
                             prog,
                             ins,
Shucai Xiao's avatar
Shucai Xiao committed
965
                             {args[0], w, r, bias, seq_lens, ih, ic, pph},
Shucai Xiao's avatar
Shucai Xiao committed
966
967
968
969
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

970
971
972
        last_hs_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
        last_cell_output =
            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);
Shucai Xiao's avatar
Shucai Xiao committed
973

Shucai Xiao's avatar
Shucai Xiao committed
974
        if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
975
        {
Shucai Xiao's avatar
Shucai Xiao committed
976
            cell_outputs = ret[3];
977
            hidden_state = prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
978
979
980
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
981
982
            auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
            auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
983
984
            cell_outputs          = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
985

Shucai Xiao's avatar
Shucai Xiao committed
986
987
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
988
989
            hidden_state     = prog.replace_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
990
991
992
        }
    }

993
994
995
996
997
    // in case of all sequences are of the same lengths and shorter than the
    // max sequence length, need to pad 0's at the end for output hidden states
    hidden_state = pad_hidden_states(prog, args[0], seq_lens, hidden_state);

    // replace last hidden states with corresponding instructions
Shucai Xiao's avatar
Shucai Xiao committed
998
    ins = replace_last_hs_output(prog, hidden_state, seq_lens, last_hs_output, dirct);
999
1000

    // replace last cell outputs with corresponding instructions
Shucai Xiao's avatar
Shucai Xiao committed
1001
    replace_last_cell_output(prog, ins, seq_lens, cell_outputs, last_cell_output, dirct);
Shucai Xiao's avatar
Shucai Xiao committed
1002
1003
1004
}

std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
1005
                                                    module& prog,
Shucai Xiao's avatar
Shucai Xiao committed
1006
1007
1008
1009
1010
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
Shucai Xiao's avatar
Shucai Xiao committed
1011
{
Shucai Xiao's avatar
Shucai Xiao committed
1012
    // must have 7 args in the input vector
Shucai Xiao's avatar
Shucai Xiao committed
1013
1014
1015
1016
1017
1018
1019
1020
1021
    assert(inputs.size() == 8);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
    auto ic       = inputs.at(6);
    auto pph      = inputs.at(7);
Shucai Xiao's avatar
Shucai Xiao committed
1022

Shucai Xiao's avatar
Shucai Xiao committed
1023
    instruction_ref hidden_states = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
1024
1025
1026
    instruction_ref cell_outputs  = prog.end();

    instruction_ref last_hs_output{};
Shucai Xiao's avatar
Shucai Xiao committed
1027
1028
    instruction_ref last_cell_output{};

1029
1030
1031
    migraphx::shape r_shape = r->get_shape();
    long hs                 = static_cast<long>(r_shape.lens()[2]);
    auto bs                 = ih->get_shape().lens()[1];
Shucai Xiao's avatar
Shucai Xiao committed
1032
1033

    std::vector<int64_t> perm{1, 0};
1034
    // w matrix, squeeze and transpose
1035
1036
    auto sw  = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sw);
Shucai Xiao's avatar
Shucai Xiao committed
1037

1038
    // r matrix, squeeze and transpose
1039
1040
    auto sr  = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"dims", perm}}), sr);
Shucai Xiao's avatar
Shucai Xiao committed
1041

Shucai Xiao's avatar
Shucai Xiao committed
1042
    // initial hidden state
1043
    auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
Shucai Xiao's avatar
Shucai Xiao committed
1044
1045

    // initial cell state
1046
    auto sic     = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
1047
    auto ic_lens = sic->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
1048

1049
    // bias
1050
    instruction_ref wrb{};
Shucai Xiao's avatar
Shucai Xiao committed
1051
    if(bias != prog.end())
1052
    {
1053

1054
1055
1056
1057
1058
1059
1060
1061
        auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto ub_wb = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias);
        auto ub_rb = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}),
            sbias);
        auto ub_wrb = prog.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);
1062

1063
        wrb = prog.insert_instruction(
1064
1065
1066
            ins,
            make_op("broadcast", {{"axis", 1}, {"dims", {bs, 4 * static_cast<size_t>(hs)}}}),
            ub_wrb);
1067
1068
    }

Shucai Xiao's avatar
Shucai Xiao committed
1069
    // peep hole
Shucai Xiao's avatar
Shucai Xiao committed
1070
1071
1072
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
Shucai Xiao's avatar
Shucai Xiao committed
1073
    if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
1074
    {
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
        auto spph = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
        auto pphi = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
        pphi_brcst = prog.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphi);

        auto ppho = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
        ppho_brcst = prog.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), ppho);

        auto pphf = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
        pphf_brcst = prog.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"dims", ic_lens}}), pphf);
Shucai Xiao's avatar
Shucai Xiao committed
1090
    }
Shucai Xiao's avatar
Shucai Xiao committed
1091

Shucai Xiao's avatar
Shucai Xiao committed
1092
    long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
Shucai Xiao's avatar
Shucai Xiao committed
1093
    for(long i = 0; i < seq_len; ++i)
Shucai Xiao's avatar
Shucai Xiao committed
1094
1095
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
        auto xt        = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);

        auto xt_tsw  = prog.insert_instruction(ins, make_op("dot"), xt, tsw);
        auto sih_tsr = prog.insert_instruction(ins, make_op("dot"), sih, tsr);
        auto xt_sih  = prog.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
Shucai Xiao's avatar
Shucai Xiao committed
1106
        if(bias != prog.end())
1107
        {
1108
            xt_sih = prog.insert_instruction(ins, make_op("add"), xt_sih, wrb);
1109
        }
Shucai Xiao's avatar
Shucai Xiao committed
1110

1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
        auto it_before_actv = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih);
        auto ot_before_actv = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih);
        auto ft_before_actv = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}),
            xt_sih);
        auto ct_before_actv = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}),
            xt_sih);
1123

Shucai Xiao's avatar
Shucai Xiao committed
1124
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
1125
        {
1126
1127
            auto pphi_ct   = prog.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
            it_before_actv = prog.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);
Shucai Xiao's avatar
Shucai Xiao committed
1128

1129
1130
            auto pphf_ct   = prog.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
            ft_before_actv = prog.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
Shucai Xiao's avatar
Shucai Xiao committed
1131
        }
Shucai Xiao's avatar
Shucai Xiao committed
1132
        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
1133
        auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
1134
        auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
1135
1136

        // equation Ct = ft (.) Ct-1 + it (.) ct
1137
1138
1139
        auto ft_cell = prog.insert_instruction(ins, make_op("mul"), ft, sic);
        auto it_ct   = prog.insert_instruction(ins, make_op("mul"), it, ct);
        auto cellt   = prog.insert_instruction(ins, make_op("add"), ft_cell, it_ct);
1140

Shucai Xiao's avatar
Shucai Xiao committed
1141
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
1142
        {
1143
1144
1145
            auto ppho_cellt = prog.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
            ot_before_actv =
                prog.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
Shucai Xiao's avatar
Shucai Xiao committed
1146
        }
1147
1148
1149
1150
        auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);

        // Ht = ot (.) h(Ct)
        auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
1151
        auto ht      = prog.insert_instruction(ins, make_op("mul"), ot, h_cellt);
1152
1153
1154
1155

        sic = cellt;
        sih = ht;

1156
1157
1158
        last_hs_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
        last_cell_output =
            prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);
1159

Shucai Xiao's avatar
Shucai Xiao committed
1160
        if(i < seq_len - 1)
1161
        {
Shucai Xiao's avatar
Shucai Xiao committed
1162
            if(i == 0)
1163
            {
Shucai Xiao's avatar
Shucai Xiao committed
1164
1165
                hidden_states = last_hs_output;
                cell_outputs  = last_cell_output;
1166
1167
1168
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
1169
1170
                auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
                auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
1171
1172
                hidden_states       = prog.insert_instruction(
                    ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
1173
1174
1175

                auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
                auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
1176
1177
                cell_outputs          = prog.insert_instruction(
                    ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
1178
1179
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1180
    }
1181

Shucai Xiao's avatar
Shucai Xiao committed
1182
    return {hidden_states, last_hs_output, cell_outputs, last_cell_output};
Shucai Xiao's avatar
Shucai Xiao committed
1183
1184
1185
1186
1187
1188
1189
1190
1191
}

std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    auto lstm_op = any_cast<op::lstm>(ins->get_operator());
    // Before rewriting the lstm operator, we need to ensure we have 6 actv
    // funcs, even though a user does not specify any actv func. If fewer
    // than 6 are given, use the algorithm in parse_lstm to make 6 actv
    // functions.
    const auto& actv_funcs     = lstm_op.actv_funcs;
    std::size_t num_actv_funcs = actv_funcs.size();
    if(lstm_op.direction == op::rnn_direction::bidirectional)
    {
        // bidirectional: {f, g, h} for the forward pass followed by
        // {f, g, h} for the reverse pass
        switch(num_actv_funcs)
        {
        case 0:
            return {make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh"),
                    make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh")};

        case 1:
            return {actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0)};

        case 2:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1)};

        case 3:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2)};

        case 4:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(3),
                    actv_funcs.at(3)};

        case 5:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(4),
                    actv_funcs.at(4)};

        default: return actv_funcs;
        }
    }
    else
    {
        // single direction: {f, g, h}
        switch(num_actv_funcs)
        {
        case 0: return {make_op("sigmoid"), make_op("tanh"), make_op("tanh")};

        case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};

        case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};

        default: return actv_funcs;
        }
    }
}

1264
bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const
Shucai Xiao's avatar
Shucai Xiao committed
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
{
    bool is_var_lens = false;
    if(seq_lens != prog.end())
    {
        if(seq_lens->can_eval())
        {
            auto arg_lens = seq_lens->eval();
            std::vector<int64_t> vec_lens;
            arg_lens.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
            int64_t l = 0;
            if(!vec_lens.empty())
            {
                l = vec_lens[0];
            }
            if(!std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
            {
                is_var_lens = true;
            }
        }
        else
        {
            is_var_lens = true;
        }
    }

    return is_var_lens;
}

std::size_t
1294
rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const
Shucai Xiao's avatar
Shucai Xiao committed
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
{
    bool is_var_lens = is_variable_seq_lens(prog, seq_lens);
    auto input_shape = input->get_shape();
    auto length      = input_shape.lens()[0];
    if(!is_var_lens and seq_lens != prog.end())
    {
        auto arg_len = seq_lens->eval();
        std::vector<std::size_t> vec_lens;
        arg_len.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
        length = vec_lens.empty() ? length : vec_lens[0];
    }

    return length;
}

1310
instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
Shucai Xiao's avatar
Shucai Xiao committed
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
                                                    instruction_ref ins,
                                                    instruction_ref seq_lens,
                                                    instruction_ref last_hs_output,
                                                    op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
    instruction_ref result_ins{};
    if(variable_seq_len)
    {
        result_ins = prog.insert_instruction(
1321
1322
1323
1324
1325
            std::next(ins),
            make_op("rnn_var_sl_shift_output",
                    {{"output_name", "hidden_states"}, {"direction", dirct}}),
            ins,
            seq_lens);
Shucai Xiao's avatar
Shucai Xiao committed
1326
        prog.replace_instruction(ins, result_ins);
1327
1328
        auto hs_outputs = find_all(result_ins->outputs(),
                                   [&](auto i) { return i->name() == "rnn_last_hs_output"; });
Shucai Xiao's avatar
Shucai Xiao committed
1329

1330
        for(auto& hs_out : hs_outputs)
Shucai Xiao's avatar
Shucai Xiao committed
1331
        {
1332
            auto inputs = hs_out->inputs();
1333
1334
1335
1336
            prog.replace_instruction(hs_out,
                                     make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                     inputs.front(),
                                     seq_lens);
Shucai Xiao's avatar
Shucai Xiao committed
1337
1338
1339
1340
        }
    }
    else
    {
1341
1342
        auto hs_outputs =
            find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_hs_output"; });
Shucai Xiao's avatar
Shucai Xiao committed
1343

1344
1345
1346
        for(auto& hs_out : hs_outputs)
        {
            prog.replace_instruction(hs_out, last_hs_output);
Shucai Xiao's avatar
Shucai Xiao committed
1347
        }
1348

Shucai Xiao's avatar
Shucai Xiao committed
1349
1350
1351
1352
1353
1354
        result_ins = ins;
    }

    return result_ins;
}

1355
// Replaces every use of the placeholder "rnn_last_cell_output" instruction
// attached to `ins` with the real last cell state.
//
// prog             - module being rewritten
// ins              - the rnn/lstm instruction whose outputs are scanned
// seq_lens         - sequence-length input of the rnn operator
// cell_outputs     - instruction producing the per-step cell states
// last_cell_output - instruction producing the final cell state
// dirct            - rnn direction (forward / reverse / bidirectional)
void rewrite_rnn::replace_last_cell_output(module& prog,
                                           instruction_ref ins,
                                           instruction_ref seq_lens,
                                           instruction_ref cell_outputs,
                                           instruction_ref last_cell_output,
                                           op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

    auto ins_outputs =
        find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; });

    if(variable_seq_len)
    {
        if(!ins_outputs.empty())
        {
            // Align the per-step cell states with the per-batch sequence
            // lengths before extracting the last valid time step.
            cell_outputs = prog.insert_instruction(
                std::next(ins),
                make_op("rnn_var_sl_shift_output",
                        {{"output_name", "cell_outputs"}, {"direction", dirct}}),
                cell_outputs,
                seq_lens);
        }

        for(auto co : ins_outputs)
        {
            prog.replace_instruction(co,
                                     make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                     cell_outputs,
                                     seq_lens);
        }
    }
    // replace the rnn_last_cell_output with the last_cell_output. The loop
    // handles the case of multiple rnn_last_cell_output operators
    else
    {
        for(auto co : ins_outputs)
        {
            prog.replace_instruction(co, last_cell_output);
        }
    }
}

// Pads the hidden-state output `hs` along the sequence axis (axis 0) with
// zeros so that it covers max_seq_len time steps. Used when all sequences
// share one fixed length that is shorter than the maximum sequence length.
//
// Returns the padded hidden-state instruction (or `hs` unchanged when no
// padding is required).
instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
                                               instruction_ref seq,
                                               instruction_ref seq_lens,
                                               instruction_ref hs) const
{
    auto max_seq_len = seq->get_shape().lens()[0];
    auto seq_len     = get_seq_len(prog, seq, seq_lens);

    // condition of all sequence are of the same length and
    // less than max_seq_len, we need to append the hs outputs
    auto hs_padded = hs;
    if(seq_len < max_seq_len)
    {
        auto s        = hs->get_shape();
        auto pad_lens = s.lens();
        // only the sequence dimension grows; all other dims stay the same
        pad_lens[0]   = static_cast<std::size_t>(max_seq_len - seq_len);
        shape pad_s{s.type(), pad_lens};
        std::vector<float> pad_data(pad_s.elements(), 0.0f);

        auto pl = prog.add_literal(pad_s, pad_data.begin(), pad_data.end());
        hs_padded =
            prog.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
        prog.replace_instruction(hs, hs_padded);
    }

    return hs_padded;
}

namespace op {
Shucai Xiao's avatar
Shucai Xiao committed
1425
std::ostream& operator<<(std::ostream& os, rnn_direction v)
1426
{
Shucai Xiao's avatar
Shucai Xiao committed
1427
1428
    std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
1429
1430
    return os;
}
Shucai Xiao's avatar
Shucai Xiao committed
1431
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx