/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/make_op.hpp>

#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/ranges.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_rnn::apply(module& m) const
{
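    // This pass unrolls the rnn, gru, and lstm operators into sequences of
    // primitive instructions (slice, dot, add, activation, concat, ...),
    // one time step at a time.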
    for(auto ins : iterator_for(m))
    {
        if(ins->name() == "rnn")
        {
            apply_vanilla_rnn(m, ins);
        }
        else if(ins->name() == "gru")
        {
            apply_gru(m, ins);
        }
        else if(ins->name() == "lstm")
        {
            apply_lstm(m, ins);
        }
    }
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
{
    assert(ins->name() == "rnn");
    // There could be 3 to 6 inputs. The parse_rnn function appends undefined
    // operators to make 6 arguments when parsing an onnx file, but a user may
    // pass any number of arguments when writing their module directly.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0);
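    // ih_shape and data describe a zero-filled literal that serves as the
    // default initial hidden state when none is provided.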

    auto actv_funcs         = vanilla_rnn_actv_funcs(ins);
    auto rnn_op             = any_cast<op::rnn>(ins->get_operator());
    op::rnn_direction dirct = rnn_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_output{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // input weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // hidden state weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // process bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // process initial hidden state, it could be the 6th argument
        // or the 5th one (if the sequence len argument is ignored)
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward =
            vanilla_rnn_cell(true,
                             m,
                             ins,
                             {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                             actv_funcs.at(0));

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret_reverse =
            vanilla_rnn_cell(false,
                             m,
                             ins,
                             {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                             actv_funcs.at(1));

        auto concat_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);

        // The following logic is to ensure the last instruction rewritten from
        // the rnn operator is a concat instruction
        // (sequence length is 1)
        if(ret_forward[0] == m.end())
        {
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        }
        else
        {
            ret_forward[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // input weight matrix
        auto w = args[1];

        // hidden state weight matrix
        auto r = args[2];

        // process bias and initial hidden state
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // process initial hidden state
        instruction_ref ih;
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ih_shape, data});
        }

        if(not is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret = vanilla_rnn_cell(
            is_forward, m, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);

        // The following logic ensures the last instruction is a
        // concat instruction
        // (sequence length is 1)
        if(ret[0] == m.end())
        {
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // If all sequences have the same length but are shorter than the max
    // sequence length, pad the output hidden states with 0's at the end.
    ins = pad_hidden_states(m, args[0], seq_lens, ins);
    replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}

std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
                                                           module& m,
                                                           instruction_ref ins,
                                                           std::vector<instruction_ref> inputs,
                                                           operation& actv_func) const
{
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);

    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
    auto sw      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tran_sw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);

    // squeeze and transpose r
    auto sr      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto tran_sr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);

    // initial hidden state
    auto sih      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
    auto sih_lens = sih->get_shape().lens();

    // bias
    instruction_ref bb{};
    if(bias != m.end())
    {
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias);
        auto rb = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
        auto wrb = m.insert_instruction(ins, make_op("add"), wb, rb);
        bb       = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
    }

    instruction_ref hidden_out = m.end();
    instruction_ref last_out{};
    last_out     = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
    long seq_len = get_seq_len(m, seq, seq_lens);
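    // Each loop iteration below implements one time step of the vanilla RNN
    // recurrence, assuming the usual (ONNX-style) formulation
    //     Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
    // where f is the activation function passed in as actv_func.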
    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
        auto xt_wi   = m.insert_instruction(ins, make_op("dot"), xt, tran_sw);
        auto ht_ri   = m.insert_instruction(ins, make_op("dot"), sih, tran_sr);
        if(bias != m.end())
        {
            xt_wi = m.insert_instruction(ins, make_op("add"), xt_wi, bb);
        }
        auto xt_ht = m.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);

        // apply activation function
        auto ht = m.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;

        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions)
        last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);

        // concatenation of the final last_out is performed in the apply()
        // function to ensure the last instruction is a concat, after which
        // the output is inserted
        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_out = (seq_index == 0)
                                 ? last_out
                                 : m.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out);
            }
            else
            {
                hidden_out = (seq_index == seq_len - 1)
                                 ? last_out
                                 : m.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out);
            }
        }
    }

    return {hidden_out, last_out};
}

std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
    // There could be 3 to 6 inputs. The parse_gru function appends undefined
    // operators to make 6 arguments when parsing an onnx file, but a user may
    // pass any number of arguments when writing their program directly.
    if(rnn_op.direction == op::rnn_direction::bidirectional)
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {make_op("tanh"), make_op("tanh")};
        }
        else if(rnn_op.actv_funcs.size() == 1)
        {
            return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
    else
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {make_op("tanh")};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);
    // There could be 3 to 6 inputs. The parse_gru function appends undefined
    // operators to make 6 arguments when parsing an onnx file, but a user may
    // pass any number of arguments when writing their program directly.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0.0);

    auto gru_op             = any_cast<op::gru>(ins->get_operator());
    op::rnn_direction dirct = gru_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_output{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // w weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // r weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // initial hidden state
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward =
            gru_cell(true,
                     m,
                     ins,
                     {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                     gru_op.linear_before_reset,
                     actv_funcs.at(0),
                     actv_funcs.at(1));

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }
        auto ret_reverse =
            gru_cell(false,
                     m,
                     ins,
                     {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                     gru_op.linear_before_reset,
                     actv_funcs.at(2),
                     actv_funcs.at(3));

        auto concat_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);

        // The following logic is to ensure the last instruction rewritten
        // from the gru operator is a concat
        if(ret_forward[0] == m.end())
        {
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        }
        else
        {
            ret_forward[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // weight matrix
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ih_shape, data});
        }

        if(not is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret = gru_cell(is_forward,
                            m,
                            ins,
                            {args[0], w, r, bias, seq_lens, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);

        if(ret[0] == m.end())
        {
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // If all sequences have the same length but are shorter than the max
    // sequence length, pad the output hidden states with 0's at the end.
    ins = pad_hidden_states(m, args[0], seq_lens, ins);
    replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
                                                   module& m,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
{
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);

    instruction_ref hidden_states = m.end();
    instruction_ref last_output{};
    migraphx::shape seq_shape = seq->get_shape();
    migraphx::shape r_shape   = r->get_shape();
    long hs                   = r_shape.lens()[2];

    migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
    std::vector<float> data(ss.elements(), 1.0f);
    auto l1 = m.add_literal(migraphx::literal{ss, data});
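    // l1 is an all-ones tensor; it is used below to form (1 - zt) when
    // computing Ht = (1 - zt) (.) ht + zt (.) Ht-1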

    // squeeze w to 2 dims and transpose
    std::vector<int64_t> perm{1, 0};
    auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);

    // slice r into two parts, zr and h
    auto sr  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto rzr = m.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
    auto trzr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);

    auto rh = m.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
    auto trh = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);

    // initial states
    auto sih  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
    size_t bs = ih->get_shape().lens()[1];

    // bias
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
    if(bias != m.end())
    {
        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
        bwb = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
            wb);

        auto rb_zr = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}),
            sbias);
        auto rb_h = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}),
            sbias);
        brb_zr = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
            rb_zr);
        brb_h = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
            rb_h);
    }

    long seq_len = get_seq_len(m, seq, seq_lens);
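    // Each loop iteration below implements one GRU time step, assuming the
    // usual (ONNX-style) gate equations:
    //     zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
    //     rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
    //     ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
    //          (or the linear_before_reset variant handled below)
    //     Ht = (1 - zt) (.) ht + zt (.) Ht-1
    // with f = actv_func1 and g = actv_func2.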
    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);

        auto xt_w    = m.insert_instruction(ins, make_op("dot"), xt, tw);
        auto ih1_rzr = m.insert_instruction(ins, make_op("dot"), sih, trzr);
        if(bias != m.end())
        {
            xt_w    = m.insert_instruction(ins, make_op("add"), xt_w, bwb);
            ih1_rzr = m.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
        }

        auto xw_z = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w);
        auto xw_r = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w);
        auto xw_h = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w);

        auto hr_z = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr);
        auto hr_r = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr);

        auto xw_hr_z = m.insert_instruction(ins, make_op("add"), xw_z, hr_z);
        auto zt      = m.insert_instruction(ins, actv_func1, xw_hr_z);

        auto xw_hr_r = m.insert_instruction(ins, make_op("add"), xw_r, hr_r);
        auto rt      = m.insert_instruction(ins, actv_func1, xw_hr_r);

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto rt_ht1 = m.insert_instruction(ins, make_op("mul"), rt, sih);
            hr_h        = m.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
            if(bias != m.end())
            {
                hr_h = m.insert_instruction(ins, make_op("add"), hr_h, brb_h);
            }
        }
        else
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            auto ht1_rh = m.insert_instruction(ins, make_op("dot"), sih, trh);
            if(bias != m.end())
            {
                ht1_rh = m.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
            }
            hr_h = m.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
        }

        auto xw_hr_h = m.insert_instruction(ins, make_op("add"), xw_h, hr_h);
        auto ht      = m.insert_instruction(ins, actv_func2, xw_hr_h);

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = m.insert_instruction(ins, make_op("sub"), l1, zt);
        auto one_minus_zt_ht = m.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
        auto zt_ht1          = m.insert_instruction(ins, make_op("mul"), zt, sih);
        sih                  = m.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
        last_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : m.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output);
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : m.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}

std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    auto gru_op = any_cast<op::gru>(ins->get_operator());
    // before rewriting the gru operator, we need to ensure
    // we have 4 actv funcs, even though a user may not
    // specify any actv func. If fewer than 4, use the
    // algorithm in parse_gru to make 4 actv functions
    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        if(gru_op.actv_funcs.empty())
            return {make_op("sigmoid"), make_op("tanh"), make_op("sigmoid"), make_op("tanh")};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0)};
        else if(gru_op.actv_funcs.size() == 2)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1)};
        else if(gru_op.actv_funcs.size() == 3)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(2),
                    gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
    else
    {
        if(gru_op.actv_funcs.empty())
            return {make_op("sigmoid"), make_op("tanh")};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
}

// for lstm operators
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};

    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
    op::rnn_direction dirct = lstm_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_hs_output{};
    instruction_ref last_cell_output{};
    instruction_ref hidden_state{};
    instruction_ref cell_outputs{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // input weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // hidden state weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // process bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // process initial hidden state, it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
        if(args.size() >= 7 && args[6]->name() != "undefined")
        {
            ic_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
            ic_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]);
        }
        else
        {
            ic_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process the peephole weights
        instruction_ref pph_forward = m.end();
        instruction_ref pph_reverse = m.end();
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
            pph_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]);
        }

        auto ret_forward = lstm_cell(true,
                                     m,
                                     ins,
                                     {args[0],
                                      w_forward,
                                      r_forward,
                                      bias_forward,
                                      seq_lens,
                                      ih_forward,
                                      ic_forward,
                                      pph_forward},
                                     actv_funcs.at(0),
                                     actv_funcs.at(1),
                                     actv_funcs.at(2));

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }
        auto ret_reverse = lstm_cell(false,
                                     m,
                                     ins,
                                     {args[0],
                                      w_reverse,
                                      r_reverse,
                                      bias_reverse,
                                      seq_lens,
                                      ih_reverse,
                                      ic_reverse,
                                      pph_reverse},
                                     actv_funcs.at(3),
                                     actv_funcs.at(4),
                                     actv_funcs.at(5));

        auto concat_hs_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        auto concat_cell_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
        last_hs_output =
            m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
        last_cell_output =
            m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);

        // the following logic is to ensure the last instruction is a concat
        if(ret_forward[0] == m.end())
        {
            cell_outputs = concat_cell_output;
        }
        else
        {
            ret_forward[1] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[1] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);

            ret_forward[3] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]);
            ret_reverse[3] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]);
            cell_outputs = m.insert_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
        }

        hidden_state = m.replace_instruction(
            ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]});
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
        if(args.size() >= 7 && args[6]->name() != "undefined")
        {
            ic = args[6];
        }
        else
        {
            ic = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process the peephole weights
        instruction_ref pph = m.end();
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }

        if(not is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }
        auto ret = lstm_cell(is_forward,
                             m,
                             ins,
                             {args[0], w, r, bias, seq_lens, ih, ic, pph},
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

        last_hs_output   = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
        last_cell_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);

        if(ret[0] == m.end())
        {
            cell_outputs = ret[3];
            hidden_state = m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
            auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
            cell_outputs          = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);

            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            hidden_state     = m.replace_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // If all sequences have the same length but are shorter than the max
    // sequence length, pad the output hidden states with 0's at the end.
    hidden_state = pad_hidden_states(m, args[0], seq_lens, hidden_state);

    // replace last hidden states with corresponding instructions
    ins = replace_last_hs_output(m, hidden_state, seq_lens, last_hs_output, dirct);

    // replace last cell outputs with corresponding instructions
    replace_last_cell_output(m, ins, seq_lens, cell_outputs, last_cell_output, dirct);
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
                                                    module& m,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
{
    // must have 8 args in the input vector
    assert(inputs.size() == 8);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
    auto ic       = inputs.at(6);
    auto pph      = inputs.at(7);

    instruction_ref hidden_states = m.end();
    instruction_ref cell_outputs  = m.end();

    instruction_ref last_hs_output{};
    instruction_ref last_cell_output{};

    migraphx::shape r_shape = r->get_shape();
    long hs                 = r_shape.lens()[2];
    auto bs                 = ih->get_shape().lens()[1];

    std::vector<int64_t> perm{1, 0};
    // w matrix, squeeze and transpose
    auto sw  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tsw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);

    // r matrix, squeeze and transpose
    auto sr  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto tsr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);

    // initial hidden state
    auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);

    // initial cell state
    auto sic     = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
    auto ic_lens = sic->get_shape().lens();

    // bias
    instruction_ref wrb{};
    if(bias != m.end())
    {
        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto ub_wb = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias);
        auto ub_rb = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}),
            sbias);
        auto ub_wrb = m.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);

        wrb = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
            ub_wrb);
    }

    // peephole weights
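    // The peephole tensor is assumed to hold the per-gate vectors [Pi, Po, Pf];
    // below they are sliced out and broadcast so they can be multiplied
    // element-wise with the cell state for the input, forget, and output gates.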
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
    if(pph != m.end())
    {
        auto spph = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
        auto pphi = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
        pphi_brcst = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);

        auto ppho = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
        ppho_brcst = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);

        auto pphf = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
        pphf_brcst = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
    }
    long seq_len = get_seq_len(m, seq, seq_lens);
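    // Each loop iteration below implements one LSTM time step, assuming the
    // usual (ONNX-style) gate equations:
    //     it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
    //     ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
    //     ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
    //     Ct = ft (.) Ct-1 + it (.) ct
    //     ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
    //     Ht = ot (.) h(Ct)
    // with f = actv_func1, g = actv_func2, and h = actv_func3.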
    for(long i = 0; i < seq_len; ++i)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);

        auto xt_tsw  = m.insert_instruction(ins, make_op("dot"), xt, tsw);
        auto sih_tsr = m.insert_instruction(ins, make_op("dot"), sih, tsr);
        auto xt_sih  = m.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
        if(bias != m.end())
        {
            xt_sih = m.insert_instruction(ins, make_op("add"), xt_sih, wrb);
        }

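        // xt_sih holds all four gates fused along axis 1 in the order
        // [input, output, forget, cell], so each gate is read back as an hs-wide slice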
        auto it_before_actv = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih);
        auto ot_before_actv = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih);
        auto ft_before_actv = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}),
            xt_sih);
        auto ct_before_actv = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}),
            xt_sih);

        if(pph != m.end())
        {
            auto pphi_ct   = m.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
            it_before_actv = m.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);

            auto pphf_ct   = m.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
            ft_before_actv = m.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
        }
        auto it = m.insert_instruction(ins, actv_func1, it_before_actv);
        auto ft = m.insert_instruction(ins, actv_func1, ft_before_actv);
        auto ct = m.insert_instruction(ins, actv_func2, ct_before_actv);

        // equation Ct = ft (.) Ct-1 + it (.) ct
        auto ft_cell = m.insert_instruction(ins, make_op("mul"), ft, sic);
        auto it_ct   = m.insert_instruction(ins, make_op("mul"), it, ct);
        auto cellt   = m.insert_instruction(ins, make_op("add"), ft_cell, it_ct);

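        // the peephole terms for the input and forget gates (above) use the previous
        // cell state; the output gate below peeks at the newly computed cell state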
        if(pph != m.end())
        {
            auto ppho_cellt = m.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
            ot_before_actv  = m.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
        }
        auto ot = m.insert_instruction(ins, actv_func1, ot_before_actv);

        // Ht = ot (.) h(Ct)
        auto h_cellt = m.insert_instruction(ins, actv_func3, cellt);
        auto ht      = m.insert_instruction(ins, make_op("mul"), ot, h_cellt);

        sic = cellt;
        sih = ht;

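        // unsqueeze re-adds the leading sequence-step and direction axes (both size 1)
        // so the per-step results can later be concatenated along axis 0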
        last_hs_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
        last_cell_output =
            m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);

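        // accumulate the per-step outputs for every step except the final one (the
        // final step is already captured in last_hs_output / last_cell_output); in the
        // reverse direction the newer step is concatenated in front so the accumulated
        // tensors stay in ascending sequence order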
        if(i < seq_len - 1)
        {
            if(i == 0)
            {
                hidden_states = last_hs_output;
                cell_outputs  = last_cell_output;
            }
            else
            {
                auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
                auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
                hidden_states       = m.insert_instruction(
                    ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1);

                auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
                auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
                cell_outputs          = m.insert_instruction(
                    ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
            }
        }
    }

    return {hidden_states, last_hs_output, cell_outputs, last_cell_output};
}

std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    auto lstm_op = any_cast<op::lstm>(ins->get_operator());
    // before rewriting the lstm operator, we need to ensure
    // there are 6 actv funcs, even though a user may not
    // specify any actv func. If fewer than 6 are given, use the
    // algorithm in parse_lstm to expand them to 6 actv functions
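    // (e.g. a bidirectional lstm given {sigmoid, tanh} ends up with
    // {sigmoid, tanh, tanh, sigmoid, tanh, tanh}, one triple per direction)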
    const auto& actv_funcs     = lstm_op.actv_funcs;
    std::size_t num_actv_funcs = actv_funcs.size();
    if(lstm_op.direction == op::rnn_direction::bidirectional)
    {
        switch(num_actv_funcs)
        {
        case 0:
            return {make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh"),
                    make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh")};

        case 1:
            return {actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0)};

        case 2:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1)};

        case 3:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2)};

        case 4:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(3),
                    actv_funcs.at(3)};

        case 5:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(4),
                    actv_funcs.at(4)};

        default: return actv_funcs;
        }
    }
    else
    {
        switch(num_actv_funcs)
        {
        case 0: return {make_op("sigmoid"), make_op("tanh"), make_op("tanh")};

        case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};

        case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};

        default: return actv_funcs;
        }
    }
}

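// A sequence_lens input is treated as variable when it cannot be evaluated at
// compile time, or when the evaluated per-batch lengths are not all identical.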
bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens) const
{
    bool is_var_lens = false;
    if(seq_lens != m.end())
    {
        if(seq_lens->can_eval())
        {
            auto arg_lens = seq_lens->eval();
            std::vector<int64_t> vec_lens;
            arg_lens.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
            int64_t l = 0;
            if(not vec_lens.empty())
            {
                l = vec_lens[0];
            }
            if(not std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
            {
                is_var_lens = true;
            }
        }
        else
        {
            is_var_lens = true;
        }
    }

    return is_var_lens;
}

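// The effective sequence length used to unroll the cell: the first dimension of the
// input by default, or the (uniform) value from a constant sequence_lens input.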
std::size_t
rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const
{
    bool is_var_lens = is_variable_seq_lens(m, seq_lens);
    auto input_shape = input->get_shape();
    auto length      = input_shape.lens()[0];
    if(not is_var_lens and seq_lens != m.end())
    {
        auto arg_len = seq_lens->eval();
        std::vector<std::size_t> vec_lens;
        arg_len.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
        length = vec_lens.empty() ? length : vec_lens[0];
    }

    return length;
}

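// Rewire uses of rnn_last_hs_output. With variable sequence lengths, the hidden-state
// output is first shifted with rnn_var_sl_shift_output and the last step is recovered
// per batch with rnn_var_sl_last_output; otherwise the precomputed last_hs_output
// simply replaces each rnn_last_hs_output instruction.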
instruction_ref rewrite_rnn::replace_last_hs_output(module& m,
                                                    instruction_ref ins,
                                                    instruction_ref seq_lens,
                                                    instruction_ref last_hs_output,
                                                    op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
    instruction_ref result_ins{};
    if(variable_seq_len)
    {
        result_ins =
            m.insert_instruction(std::next(ins),
                                 make_op("rnn_var_sl_shift_output",
                                         {{"output_name", "hidden_states"}, {"direction", dirct}}),
                                 ins,
                                 seq_lens);
        m.replace_instruction(ins, result_ins);
        auto hs_outputs = find_all(result_ins->outputs(),
                                   [&](auto i) { return i->name() == "rnn_last_hs_output"; });

        for(auto& hs_out : hs_outputs)
        {
            auto inputs = hs_out->inputs();
            m.replace_instruction(hs_out,
                                  make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                  inputs.front(),
                                  seq_lens);
        }
    }
    else
    {
        auto hs_outputs =
            find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_hs_output"; });

        for(auto& hs_out : hs_outputs)
        {
            m.replace_instruction(hs_out, last_hs_output);
        }

        result_ins = ins;
    }

    return result_ins;
}

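// Rewire uses of rnn_last_cell_output in the same way: with variable sequence lengths
// the cell outputs are shifted and the per-batch last cell state is recovered with
// rnn_var_sl_last_output; otherwise the precomputed last_cell_output is substituted.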
void rewrite_rnn::replace_last_cell_output(module& m,
                                           instruction_ref ins,
                                           instruction_ref seq_lens,
                                           instruction_ref cell_outputs,
                                           instruction_ref last_cell_output,
                                           op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
    auto ins_outputs =
        find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; });

    if(variable_seq_len)
    {
        if(not ins_outputs.empty())
        {
            cell_outputs = m.insert_instruction(
                std::next(ins),
                make_op("rnn_var_sl_shift_output",
                        {{"output_name", "cell_outputs"}, {"direction", dirct}}),
                cell_outputs,
                seq_lens);
        }

        for(auto co : ins_outputs)
        {
            m.replace_instruction(co,
                                  make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                  cell_outputs,
                                  seq_lens);
        }
    }
    // replace each rnn_last_cell_output with the last_cell_output; the loop
    // handles the case of multiple rnn_last_cell_output operators
    else
    {
        for(auto co : ins_outputs)
        {
            m.replace_instruction(co, last_cell_output);
        }
    }
}

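// When the (constant) sequence length is shorter than the input's maximum sequence
// length, the computed hidden states only cover seq_len steps; pad them with zeros
// along the sequence axis so the output keeps the expected max_seq_len extent.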
instruction_ref rewrite_rnn::pad_hidden_states(module& m,
                                               instruction_ref seq,
                                               instruction_ref seq_lens,
                                               instruction_ref hs) const
{
    auto max_seq_len = seq->get_shape().lens()[0];
    auto seq_len     = get_seq_len(m, seq, seq_lens);

    // when all sequences have the same length and it is less than
    // max_seq_len, we need to pad the hs outputs up to max_seq_len
    auto hs_padded = hs;
    if(seq_len < max_seq_len)
    {
        auto s        = hs->get_shape();
        auto pad_lens = s.lens();
        pad_lens[0]   = static_cast<std::size_t>(max_seq_len - seq_len);
        shape pad_s{s.type(), pad_lens};
        std::vector<float> pad_data(pad_s.elements(), 0.0f);
        auto pl   = m.add_literal(pad_s, pad_data.begin(), pad_data.end());
        hs_padded = m.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
        m.replace_instruction(hs, hs_padded);
    }

    return hs_padded;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx