/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/make_op.hpp>

#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/ranges.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

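// This pass rewrites the rnn, gru, and lstm operators into sequences of
// primitive instructions (slice, squeeze, dot, add, mul, activation,
// concat, ...) that implement the same recurrences one time step at a time.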
void rewrite_rnn::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() == "rnn")
        {
            apply_vanilla_rnn(m, ins);
        }
        else if(ins->name() == "gru")
        {
            apply_gru(m, ins);
        }
        else if(ins->name() == "lstm")
        {
            apply_lstm(m, ins);
        }
    }
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
{
    assert(ins->name() == "rnn");
    // The rnn operator could have 3 to 6 inputs, but the parse_rnn function
    // appends undefined operators to make 6 arguments when parsing an onnx
    // file. A user writing their own module, however, may pass any number
    // of arguments in that range.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
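    // zero-filled literal of shape {num_directions = 1, batch_size, hidden_size};
    // it is used as the default initial hidden state when none is provided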
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0);

    auto actv_funcs         = vanilla_rnn_actv_funcs(ins);
    auto rnn_op             = any_cast<op::rnn>(ins->get_operator());
    op::rnn_direction dirct = rnn_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_output{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // input weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // hidden state weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // process bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // process initial hidden state; it could be the 6th argument
        // or the 5th one (if the sequence length argument is ignored)
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward =
            vanilla_rnn_cell(true,
                             m,
                             ins,
                             {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                             actv_funcs.at(0));

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret_reverse =
            vanilla_rnn_cell(false,
                             m,
                             ins,
                             {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                             actv_funcs.at(1));

        auto concat_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);

        // The following logic ensures that the last instruction rewritten
        // from the rnn operator is a concat instruction.
        // ret_forward[0] == m.end() means the sequence length is 1.
        if(ret_forward[0] == m.end())
        {
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        }
        else
        {
            ret_forward[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // input weight matrix
        auto w = args[1];

        // hidden state weight matrix
        auto r = args[2];

        // process bias
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // process initial hidden state
        instruction_ref ih;
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ih_shape, data});
        }

        if(not is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret = vanilla_rnn_cell(
            is_forward, m, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);

        // The following logic ensures that the last instruction is a
        // concat instruction.
        // ret[0] == m.end() means the sequence length is 1.
        if(ret[0] == m.end())
        {
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // If all sequences are of the same length but shorter than the max
    // sequence length, the output hidden states need to be padded with
    // 0's at the end.
    ins = pad_hidden_states(m, args[0], seq_lens, ins);
    replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}

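// Computes one direction of a vanilla RNN. Returns {hidden_out, last_out}:
// hidden_out concatenates the hidden states of all but the last time step
// (it stays m.end() when the sequence length is 1) and last_out is the
// hidden state of the last time step. The final concatenation is done by
// the caller so that the rewritten output ends in a concat instruction.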
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
                                                           module& m,
                                                           instruction_ref ins,
                                                           std::vector<instruction_ref> inputs,
                                                           operation& actv_func) const
{
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);

    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
    auto sw      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tran_sw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);

    // squeeze and transpose r
    auto sr      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto tran_sr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);

    // initial hidden state
    auto sih      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
    auto sih_lens = sih->get_shape().lens();

    // bias
    instruction_ref bb{};
    if(bias != m.end())
    {
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias);
        auto rb = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
        auto wrb = m.insert_instruction(ins, make_op("add"), wb, rb);
        bb       = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
    }

    instruction_ref hidden_out = m.end();
    instruction_ref last_out{};
    last_out     = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
    long seq_len = get_seq_len(m, seq, seq_lens);
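    // Each iteration computes one time step of the rnn recurrence
    //   Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
    // where f is the activation function; when a bias is given, the sum
    // Wbi + Rbi has already been broadcast into bb above.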
    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
        auto xt_wi   = m.insert_instruction(ins, make_op("dot"), xt, tran_sw);
        auto ht_ri   = m.insert_instruction(ins, make_op("dot"), sih, tran_sr);
        if(bias != m.end())
        {
            xt_wi = m.insert_instruction(ins, make_op("add"), xt_wi, bb);
        }
        auto xt_ht = m.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);

        // apply activation function
        auto ht = m.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;

        // add back the dimensions for sequence length (axis 0) and
        // num_directions (axis 1)
        last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);

        // the concatenation with the final last_out is performed in the
        // apply() function to ensure the last instruction is a concat,
        // which is where the output is inserted
        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_out = (seq_index == 0)
                                 ? last_out
                                 : m.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out);
            }
            else
            {
                hidden_out = (seq_index == seq_len - 1)
                                 ? last_out
                                 : m.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out);
            }
        }
    }

    return {hidden_out, last_out};
}

std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
    // The rnn operator could have 3 to 6 inputs, but the parse_rnn function
    // appends undefined operators to make 6 arguments when parsing an onnx
    // file. A user writing their own program, however, may pass any number
    // of arguments in that range.
    if(rnn_op.direction == op::rnn_direction::bidirectional)
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {make_op("tanh"), make_op("tanh")};
        }
        else if(rnn_op.actv_funcs.size() == 1)
        {
            return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
    else
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {make_op("tanh")};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);
    // The gru operator could have 3 to 6 inputs, but the parse_gru function
    // appends undefined operators to make 6 arguments when parsing an onnx
    // file. A user writing their own module, however, may pass any number
    // of arguments in that range.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0.0);

    auto gru_op             = any_cast<op::gru>(ins->get_operator());
    op::rnn_direction dirct = gru_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_output{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // w weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // r weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // initial hidden state
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward =
            gru_cell(true,
                     m,
                     ins,
                     {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                     gru_op.linear_before_reset,
                     actv_funcs.at(0),
                     actv_funcs.at(1));

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret_reverse =
            gru_cell(false,
                     m,
                     ins,
                     {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                     gru_op.linear_before_reset,
                     actv_funcs.at(2),
                     actv_funcs.at(3));

        auto concat_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);

        // The following logic ensures that the last instruction rewritten
        // from the gru operator is a concat
        if(ret_forward[0] == m.end())
        {
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        }
        else
        {
            ret_forward[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // weight matrix
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ih_shape, data});
        }

        if(not is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret = gru_cell(is_forward,
                            m,
                            ins,
                            {args[0], w, r, bias, seq_lens, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);

        if(ret[0] == m.end())
        {
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // If all sequences are of the same length but shorter than the max
    // sequence length, the output hidden states need to be padded with
    // 0's at the end.
    ins = pad_hidden_states(m, args[0], seq_lens, ins);
    replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}

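// Computes one direction of a GRU. Returns {hidden_states, last_output}:
// hidden_states concatenates the hidden states of all but the last time
// step (it stays m.end() when the sequence length is 1) and last_output
// is the hidden state of the last time step.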
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
                                                   module& m,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
{
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);

    instruction_ref hidden_states = m.end();
    instruction_ref last_output{};
    migraphx::shape seq_shape = seq->get_shape();
    migraphx::shape r_shape   = r->get_shape();
    long hs                   = r_shape.lens()[2];

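    // literal of ones with shape {batch_size, hidden_size}, used below to
    // compute (1 - zt)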
    migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
    std::vector<float> data(ss.elements(), 1.0f);
    auto l1 = m.add_literal(migraphx::literal{ss, data});

    // squeeze the w matrix to 2-dim and transpose it
    std::vector<int64_t> perm{1, 0};
    auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);

    // slice r into two parts: zr and h
    auto sr  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto rzr = m.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
    auto trzr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);

    auto rh = m.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
    auto trh = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);

    // initial states
    auto sih  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
    size_t bs = ih->get_shape().lens()[1];

    // bias
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
    if(bias != m.end())
    {
        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
        bwb = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
            wb);

        auto rb_zr = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}),
            sbias);
        auto rb_h = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}),
            sbias);
        brb_zr = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
            rb_zr);
        brb_h = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
            rb_h);
    }

    long seq_len = get_seq_len(m, seq, seq_lens);
    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);

        auto xt_w    = m.insert_instruction(ins, make_op("dot"), xt, tw);
        auto ih1_rzr = m.insert_instruction(ins, make_op("dot"), sih, trzr);
        if(bias != m.end())
        {
            xt_w    = m.insert_instruction(ins, make_op("add"), xt_w, bwb);
            ih1_rzr = m.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
        }

        auto xw_z = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w);
        auto xw_r = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w);
        auto xw_h = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w);

        auto hr_z = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr);
        auto hr_r = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr);

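        // equation zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)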
        auto xw_hr_z = m.insert_instruction(ins, make_op("add"), xw_z, hr_z);
        auto zt      = m.insert_instruction(ins, actv_func1, xw_hr_z);

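        // equation rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)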
        auto xw_hr_r = m.insert_instruction(ins, make_op("add"), xw_r, hr_r);
        auto rt      = m.insert_instruction(ins, actv_func1, xw_hr_r);

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto rt_ht1 = m.insert_instruction(ins, make_op("mul"), rt, sih);
            hr_h        = m.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
            if(bias != m.end())
            {
                hr_h = m.insert_instruction(ins, make_op("add"), hr_h, brb_h);
            }
        }
        else
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            auto ht1_rh = m.insert_instruction(ins, make_op("dot"), sih, trh);
            if(bias != m.end())
            {
                ht1_rh = m.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
            }
            hr_h = m.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
        }

        auto xw_hr_h = m.insert_instruction(ins, make_op("add"), xw_h, hr_h);
        auto ht      = m.insert_instruction(ins, actv_func2, xw_hr_h);

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = m.insert_instruction(ins, make_op("sub"), l1, zt);
        auto one_minus_zt_ht = m.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
        auto zt_ht1          = m.insert_instruction(ins, make_op("mul"), zt, sih);
        sih                  = m.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
        last_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : m.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output);
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : m.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}

std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    auto gru_op = any_cast<op::gru>(ins->get_operator());
    // Before rewriting the gru operator, we need to ensure there are
    // 4 actv funcs, even if the user did not specify any. If fewer than
    // 4 are given, use the same algorithm as parse_gru to expand them
    // to 4 actv functions.
    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        if(gru_op.actv_funcs.empty())
            return {make_op("sigmoid"), make_op("tanh"), make_op("sigmoid"), make_op("tanh")};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0)};
        else if(gru_op.actv_funcs.size() == 2)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1)};
        else if(gru_op.actv_funcs.size() == 3)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(2),
                    gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
    else
    {
        if(gru_op.actv_funcs.empty())
            return {make_op("sigmoid"), make_op("tanh")};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
}

// for lstm operators
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};

    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
    op::rnn_direction dirct = lstm_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_hs_output{};
    instruction_ref last_cell_output{};
    instruction_ref hidden_state{};
    instruction_ref cell_outputs{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // input weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // hidden state weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // process bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // process initial hidden state; it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
        if(args.size() >= 7 && args[6]->name() != "undefined")
        {
            ic_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
            ic_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]);
        }
        else
        {
            ic_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
        instruction_ref pph_forward = m.end();
        instruction_ref pph_reverse = m.end();
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
            pph_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]);
        }

        auto ret_forward = lstm_cell(true,
                                     m,
                                     ins,
                                     {args[0],
                                      w_forward,
                                      r_forward,
                                      bias_forward,
                                      seq_lens,
                                      ih_forward,
                                      ic_forward,
                                      pph_forward},
                                     actv_funcs.at(0),
                                     actv_funcs.at(1),
                                     actv_funcs.at(2));

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }
        auto ret_reverse = lstm_cell(false,
                                     m,
                                     ins,
                                     {args[0],
                                      w_reverse,
                                      r_reverse,
                                      bias_reverse,
                                      seq_lens,
                                      ih_reverse,
                                      ic_reverse,
                                      pph_reverse},
                                     actv_funcs.at(3),
                                     actv_funcs.at(4),
                                     actv_funcs.at(5));

        auto concat_hs_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        auto concat_cell_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
        last_hs_output =
            m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
        last_cell_output =
            m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);

        // the following logic ensures that the last instruction is a concat
        if(ret_forward[0] == m.end())
        {
            cell_outputs = concat_cell_output;
        }
        else
        {
            ret_forward[1] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[1] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);

            ret_forward[3] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]);
            ret_reverse[3] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]);
            cell_outputs = m.insert_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
        }

        hidden_state = m.replace_instruction(
            ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]});
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
        if(args.size() >= 7 && args[6]->name() != "undefined")
        {
            ic = args[6];
        }
        else
        {
            ic = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
        instruction_ref pph = m.end();
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }

        if(not is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }
        auto ret = lstm_cell(is_forward,
                             m,
                             ins,
                             {args[0], w, r, bias, seq_lens, ih, ic, pph},
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

        last_hs_output   = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
        last_cell_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);

        if(ret[0] == m.end())
        {
            cell_outputs = ret[3];
            hidden_state = m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
            auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
            cell_outputs          = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);

            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            hidden_state     = m.replace_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // If all sequences are of the same length but shorter than the max
    // sequence length, the output hidden states need to be padded with
    // 0's at the end.
    hidden_state = pad_hidden_states(m, args[0], seq_lens, hidden_state);

    // replace last hidden states with corresponding instructions
    ins = replace_last_hs_output(m, hidden_state, seq_lens, last_hs_output, dirct);

    // replace last cell outputs with corresponding instructions
    replace_last_cell_output(m, ins, seq_lens, cell_outputs, last_cell_output, dirct);
}

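// Computes one direction of an LSTM. Returns {hidden_states, last_hs_output,
// cell_outputs, last_cell_output}: the concatenated hidden states and cell
// states of all but the last time step, plus the hidden state and cell
// state of the last time step.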
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
                                                    module& m,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
{
    // must have 8 args in the input vector
    assert(inputs.size() == 8);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
    auto ic       = inputs.at(6);
    auto pph      = inputs.at(7);

    instruction_ref hidden_states = m.end();
    instruction_ref cell_outputs  = m.end();

    instruction_ref last_hs_output{};
    instruction_ref last_cell_output{};

    migraphx::shape r_shape = r->get_shape();
    long hs                 = r_shape.lens()[2];
    auto bs                 = ih->get_shape().lens()[1];

    std::vector<int64_t> perm{1, 0};
    // w matrix, squeeze and transpose
    auto sw  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tsw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);

    // r matrix, squeeze and transpose
    auto sr  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto tsr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);

    // initial hidden state
    auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);

    // initial cell state
    auto sic     = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
    auto ic_lens = sic->get_shape().lens();

    // bias
    instruction_ref wrb{};
    if(bias != m.end())
    {

        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto ub_wb = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias);
        auto ub_rb = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}),
            sbias);
        auto ub_wrb = m.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);

        wrb = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
            ub_wrb);
    }

    // peephole weights
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
    if(pph != m.end())
    {
        auto spph = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
        auto pphi = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
        pphi_brcst = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);

        auto ppho = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
        ppho_brcst = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);

        auto pphf = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
        pphf_brcst = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
    }

    long seq_len = get_seq_len(m, seq, seq_lens);
    for(long i = 0; i < seq_len; ++i)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);

        auto xt_tsw  = m.insert_instruction(ins, make_op("dot"), xt, tsw);
        auto sih_tsr = m.insert_instruction(ins, make_op("dot"), sih, tsr);
        auto xt_sih  = m.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
        if(bias != m.end())
1130
        {
1131
            xt_sih = m.insert_instruction(ins, make_op("add"), xt_sih, wrb);
1132
        }
Shucai Xiao's avatar
Shucai Xiao committed
1133

1134
        auto it_before_actv = m.insert_instruction(
1135
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih);
1136
        auto ot_before_actv = m.insert_instruction(
1137
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih);
1138
        auto ft_before_actv = m.insert_instruction(
1139
1140
1141
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}),
            xt_sih);
1142
        auto ct_before_actv = m.insert_instruction(
1143
1144
1145
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}),
            xt_sih);
1146

1147
        if(pph != m.end())
Shucai Xiao's avatar
Shucai Xiao committed
1148
        {
1149
1150
            auto pphi_ct   = m.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
            it_before_actv = m.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);
Shucai Xiao's avatar
Shucai Xiao committed
1151

1152
1153
            auto pphf_ct   = m.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
            ft_before_actv = m.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
Shucai Xiao's avatar
Shucai Xiao committed
1154
        }
1155
1156
1157
        auto it = m.insert_instruction(ins, actv_func1, it_before_actv);
        auto ft = m.insert_instruction(ins, actv_func1, ft_before_actv);
        auto ct = m.insert_instruction(ins, actv_func2, ct_before_actv);
1158
1159

        // equation Ct = ft (.) Ct-1 + it (.) ct
1160
1161
1162
        auto ft_cell = m.insert_instruction(ins, make_op("mul"), ft, sic);
        auto it_ct   = m.insert_instruction(ins, make_op("mul"), it, ct);
        auto cellt   = m.insert_instruction(ins, make_op("add"), ft_cell, it_ct);
1163

1164
        if(pph != m.end())
Shucai Xiao's avatar
Shucai Xiao committed
1165
        {
1166
1167
            auto ppho_cellt = m.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
            ot_before_actv  = m.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
Shucai Xiao's avatar
Shucai Xiao committed
1168
        }
1169
        auto ot = m.insert_instruction(ins, actv_func1, ot_before_actv);
1170
1171

        // Ht = ot (.) h(Ct)
1172
1173
        auto h_cellt = m.insert_instruction(ins, actv_func3, cellt);
        auto ht      = m.insert_instruction(ins, make_op("mul"), ot, h_cellt);
1174
1175
1176
1177

        sic = cellt;
        sih = ht;

1178
        last_hs_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
1179
        last_cell_output =
1180
            m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);
1181

Shucai Xiao's avatar
Shucai Xiao committed
1182
        if(i < seq_len - 1)
1183
        {
Shucai Xiao's avatar
Shucai Xiao committed
1184
            if(i == 0)
1185
            {
Shucai Xiao's avatar
Shucai Xiao committed
1186
1187
                hidden_states = last_hs_output;
                cell_outputs  = last_cell_output;
1188
1189
1190
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
1191
1192
                auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
                auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
1193
                hidden_states       = m.insert_instruction(
1194
                    ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
1195
1196
1197

                auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
                auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
1198
                cell_outputs          = m.insert_instruction(
1199
                    ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
1200
1201
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1202
    }
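    // All but the final step have been accumulated into hidden_states / cell_outputs
    // (in sequence order for both directions, since the concat arguments are swapped
    // for the reverse pass); the final step is returned separately as last_hs_output /
    // last_cell_output, each unsqueezed to [1, 1, batch, hs].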
    return {hidden_states, last_hs_output, cell_outputs, last_cell_output};
}

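// Normalize the user-supplied LSTM activation functions to the full set the rewrite
// expects: f, g, h per direction, i.e. 6 entries for bidirectional and 3 otherwise.
// For example, a bidirectional LSTM given {sigmoid, tanh} is expanded to
// {sigmoid, tanh, tanh, sigmoid, tanh, tanh}.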
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    auto lstm_op = any_cast<op::lstm>(ins->get_operator());
    // Before rewriting the lstm operator, make sure there is a full set of
    // activation functions even if the user did not specify any. If fewer are
    // given, use the same algorithm as parse_lstm to expand them.
    const auto& actv_funcs     = lstm_op.actv_funcs;
    std::size_t num_actv_funcs = actv_funcs.size();
    if(lstm_op.direction == op::rnn_direction::bidirectional)
    {
        switch(num_actv_funcs)
        {
        case 0:
            return {make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh"),
                    make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh")};

        case 1:
            return {actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0)};

        case 2:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1)};

        case 3:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2)};

        case 4:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(3),
                    actv_funcs.at(3)};

        case 5:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(4),
                    actv_funcs.at(4)};

        default: return actv_funcs;
        }
    }
    else
    {
        switch(num_actv_funcs)
        {
        case 0: return {make_op("sigmoid"), make_op("tanh"), make_op("tanh")};

        case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};

        case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};

        default: return actv_funcs;
        }
    }
}

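// Sequence lengths are treated as "variable" when the sequence_lens input is present
// and either cannot be evaluated at compile time or evaluates to lengths that are not
// all equal.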
bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens) const
{
    bool is_var_lens = false;
    if(seq_lens != m.end())
    {
        if(seq_lens->can_eval())
        {
            auto arg_lens = seq_lens->eval();
            std::vector<int64_t> vec_lens;
            arg_lens.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
            int64_t l = 0;
            if(not vec_lens.empty())
            {
                l = vec_lens[0];
            }
            if(not std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
            {
                is_var_lens = true;
            }
        }
        else
        {
            is_var_lens = true;
        }
    }

    return is_var_lens;
}

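// Returns the number of time steps to unroll: the first dimension of the input by
// default, or the common value from sequence_lens when it is a compile-time constant
// with all entries equal.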
std::size_t
rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const
{
    bool is_var_lens = is_variable_seq_lens(m, seq_lens);
    auto input_shape = input->get_shape();
    auto length      = input_shape.lens()[0];
    if(not is_var_lens and seq_lens != m.end())
    {
        auto arg_len = seq_lens->eval();
        std::vector<std::size_t> vec_lens;
        arg_len.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
        length = vec_lens.empty() ? length : vec_lens[0];
    }

    return length;
}

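// For variable sequence lengths, wrap the hidden-state output in rnn_var_sl_shift_output
// and rewrite any rnn_last_hs_output users to rnn_var_sl_last_output; otherwise replace
// the rnn_last_hs_output users directly with the precomputed last_hs_output.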
instruction_ref rewrite_rnn::replace_last_hs_output(module& m,
                                                    instruction_ref ins,
                                                    instruction_ref seq_lens,
                                                    instruction_ref last_hs_output,
                                                    op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
    instruction_ref result_ins{};
    if(variable_seq_len)
    {
        result_ins =
            m.insert_instruction(std::next(ins),
                                 make_op("rnn_var_sl_shift_output",
                                         {{"output_name", "hidden_states"}, {"direction", dirct}}),
                                 ins,
                                 seq_lens);
        m.replace_instruction(ins, result_ins);
        auto hs_outputs = find_all(result_ins->outputs(),
                                   [&](auto i) { return i->name() == "rnn_last_hs_output"; });

        for(auto& hs_out : hs_outputs)
        {
            auto inputs = hs_out->inputs();
            m.replace_instruction(hs_out,
                                  make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                  inputs.front(),
                                  seq_lens);
        }
    }
    else
    {
        auto hs_outputs =
            find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_hs_output"; });

        for(auto& hs_out : hs_outputs)
        {
            m.replace_instruction(hs_out, last_hs_output);
        }

        result_ins = ins;
    }

    return result_ins;
}

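// Counterpart of replace_last_hs_output for rnn_last_cell_output users: shift the cell
// outputs when sequence lengths are variable, otherwise substitute last_cell_output.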
void rewrite_rnn::replace_last_cell_output(module& m,
                                           instruction_ref ins,
                                           instruction_ref seq_lens,
                                           instruction_ref cell_outputs,
                                           instruction_ref last_cell_output,
                                           op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
    auto ins_outputs =
        find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; });

    if(variable_seq_len)
    {
        if(not ins_outputs.empty())
        {
            cell_outputs = m.insert_instruction(
                std::next(ins),
                make_op("rnn_var_sl_shift_output",
                        {{"output_name", "cell_outputs"}, {"direction", dirct}}),
                cell_outputs,
                seq_lens);
        }

        for(auto co : ins_outputs)
        {
            m.replace_instruction(co,
                                  make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                  cell_outputs,
                                  seq_lens);
        }
    }
    // Replace each rnn_last_cell_output with last_cell_output; the loop handles
    // the case of multiple rnn_last_cell_output operators.
    else
    {
        for(auto co : ins_outputs)
        {
            m.replace_instruction(co, last_cell_output);
        }
    }
}

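// When every sequence has the same length but it is shorter than the input's first
// dimension (max_seq_len), pad the hidden-state output with zeros along the sequence
// axis so its first dimension is max_seq_len again.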
instruction_ref rewrite_rnn::pad_hidden_states(module& m,
                                               instruction_ref seq,
                                               instruction_ref seq_lens,
                                               instruction_ref hs) const
{
    auto max_seq_len = seq->get_shape().lens()[0];
    auto seq_len     = get_seq_len(m, seq, seq_lens);

    // If all sequences have the same length and it is shorter than max_seq_len,
    // append zeros to the hs output along the sequence axis.
    auto hs_padded = hs;
    if(seq_len < max_seq_len)
    {
        auto s        = hs->get_shape();
        auto pad_lens = s.lens();
        pad_lens[0]   = static_cast<std::size_t>(max_seq_len - seq_len);
        shape pad_s{s.type(), pad_lens};
        std::vector<float> pad_data(pad_s.elements(), 0.0f);
        auto pl   = m.add_literal(pad_s, pad_data.begin(), pad_data.end());
        hs_padded = m.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
        m.replace_instruction(hs, hs_padded);
    }

    return hs_padded;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx