/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/make_op.hpp>

#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/ranges.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
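// rewrite_rnn lowers the rnn, gru and lstm operators into sequences of basic
// operators (slice, dot, add, mul, concat, squeeze/unsqueeze, ...) so that
// backends without dedicated RNN kernels can still execute them.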

void rewrite_rnn::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() == "rnn")
        {
            apply_vanilla_rnn(m, ins);
        }
        else if(ins->name() == "gru")
        {
            apply_gru(m, ins);
        }
        else if(ins->name() == "lstm")
        {
            apply_lstm(m, ins);
        }
    }
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
{
    assert(ins->name() == "rnn");
    // could be 3 to 6 inputs, but the parse_rnn function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Alternatively, a user can pass any number of
    // arguments when writing their module.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0);
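    // zero-filled data used as the default initial hidden state literal when
    // the optional initial_h input (6th argument) is not provided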

    auto actv_funcs         = vanilla_rnn_actv_funcs(ins);
    auto rnn_op             = any_cast<op::rnn>(ins->get_operator());
    op::rnn_direction dirct = rnn_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_output{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // input weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // hidden state weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // process bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // process initial hidden state, it could be the 6th argument
        // or the 5th one (if the sequence len argument is ignored)
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward =
            vanilla_rnn_cell(true,
                             m,
                             ins,
                             {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                             actv_funcs.at(0));

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret_reverse =
            vanilla_rnn_cell(false,
                             m,
                             ins,
                             {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                             actv_funcs.at(1));

        auto concat_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);

        // The following logic ensures the last instruction rewritten from the
        // rnn operator is a concat instruction.
        // sequence length is 1
        if(ret_forward[0] == m.end())
        {
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        }
        else
        {
            ret_forward[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // input weight matrix
        auto w = args[1];

        // hidden state weight matrix
        auto r = args[2];

        // process bias and initial hidden state
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // process initial hidden state
        instruction_ref ih;
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ih_shape, data});
        }

        if(!is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret = vanilla_rnn_cell(
            is_forward, m, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);

        // the following logic ensures the last instruction is a
        // concat instruction
        // sequence length is 1
        if(ret[0] == m.end())
        {
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // in case all sequences are of the same length but shorter than the
    // max sequence length, we need to pad 0's at the end of the output hidden states
    ins = pad_hidden_states(m, args[0], seq_lens, ins);
    replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}

std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
                                                           module& m,
                                                           instruction_ref ins,
                                                           std::vector<instruction_ref> inputs,
                                                           operation& actv_func) const
{
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
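    // computes the onnx vanilla RNN recurrence one time step at a time:
    //   Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
    // where f is actv_func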

    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
    auto sw      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tran_sw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);

    // squeeze and transpose r
    auto sr      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto tran_sr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);

    // initial hidden state
    auto sih      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
    auto sih_lens = sih->get_shape().lens();

    // bias
    instruction_ref bb{};
    if(bias != m.end())
    {
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias);
        auto rb = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
        auto wrb = m.insert_instruction(ins, make_op("add"), wb, rb);
        bb       = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
    }

    instruction_ref hidden_out = m.end();
    instruction_ref last_out{};
    last_out     = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
    long seq_len = get_seq_len(m, seq, seq_lens);
    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
        auto xt_wi   = m.insert_instruction(ins, make_op("dot"), xt, tran_sw);
        auto ht_ri   = m.insert_instruction(ins, make_op("dot"), sih, tran_sr);
        if(bias != m.end())
        {
            xt_wi = m.insert_instruction(ins, make_op("add"), xt_wi, bb);
        }
        auto xt_ht = m.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);

        // apply activation function
        auto ht = m.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;

        // add back the sequence-length and direction dimensions (axis 0 for
        // sequence length, axis 1 for num_directions)
        last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);

        // the concatenation with the final last_out is performed in the apply
        // functions, so that the last instruction rewritten from the operator
        // is a concat and the outputs can be inserted there
        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_out = (seq_index == 0)
                                 ? last_out
                                 : m.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out);
            }
            else
            {
                hidden_out = (seq_index == seq_len - 1)
                                 ? last_out
                                 : m.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out);
            }
        }
    }

    return {hidden_out, last_out};
}

std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
    // before rewriting the rnn operator, make sure there is one activation
    // function per direction (two for bidirectional), even if the user did
    // not specify any; missing entries default to tanh or reuse the first one
    if(rnn_op.direction == op::rnn_direction::bidirectional)
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {make_op("tanh"), make_op("tanh")};
        }
        else if(rnn_op.actv_funcs.size() == 1)
        {
            return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
    else
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {make_op("tanh")};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);
    // could be 3 to 6 inputs, but the parse_gru function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Alternatively, a user can pass any number of
    // arguments when writing their program.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0.0);
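    // hidden_size is taken from the recurrence weight R (args[2]); the zero
    // data is used as the default initial hidden state whenever initial_h is
    // not supplied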

    auto gru_op             = any_cast<op::gru>(ins->get_operator());
    op::rnn_direction dirct = gru_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_output{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // w weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // r weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // initial hidden state
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward =
            gru_cell(true,
                     m,
                     ins,
                     {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                     gru_op.linear_before_reset,
                     actv_funcs.at(0),
                     actv_funcs.at(1));
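        // gru_cell returns {hidden states of all but the last time step
        // (m.end() when seq_len == 1), hidden state of the last time step};
        // the two are concatenated below so the rewritten program ends in a
        // concat instruction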

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret_reverse =
            gru_cell(false,
                     m,
                     ins,
                     {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                     gru_op.linear_before_reset,
                     actv_funcs.at(2),
                     actv_funcs.at(3));

        auto concat_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);

        // The following logic is to ensure the last instruction rewritten
        // from the gru operator is a concat
        if(ret_forward[0] == m.end())
        {
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        }
        else
        {
            ret_forward[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // weight matrix
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ih_shape, data});
        }

        if(!is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret = gru_cell(is_forward,
                            m,
                            ins,
                            {args[0], w, r, bias, seq_lens, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);

        if(ret[0] == m.end())
        {
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // in case all sequences are of the same length but shorter than the
    // max sequence length, we need to pad 0's at the end of the output hidden states
    ins = pad_hidden_states(m, args[0], seq_lens, ins);
    replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
                                                   module& m,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
{
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
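    // W, R and B are laid out as [z | r | h] blocks of hidden_size rows,
    // following the onnx GRU definition:
    //   zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
    //   rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
    // ht and Ht are computed below, depending on linear_before_reset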

    instruction_ref hidden_states = m.end();
    instruction_ref last_output{};
    migraphx::shape seq_shape = seq->get_shape();
    migraphx::shape r_shape   = r->get_shape();
    long hs                   = r_shape.lens()[2];

    migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
    std::vector<float> data(ss.elements(), 1.0f);
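    // all-ones literal, used to form (1 - zt) in the update-gate blend
    //   Ht = (1 - zt) (.) ht + zt (.) Ht-1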
    auto l1 = m.add_literal(migraphx::literal{ss, data});

    // squeeze w to 2 dims and transpose it
    std::vector<int64_t> perm{1, 0};
    auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);

    // slice r into two parts, zr and h
    auto sr  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto rzr = m.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
    auto trzr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);

    auto rh = m.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
    auto trh = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);

    // initial states
    auto sih  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
    size_t bs = ih->get_shape().lens()[1];

    // bias
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
    if(bias != m.end())
    {
        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
        bwb = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
            wb);

        auto rb_zr = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}),
            sbias);
        auto rb_h = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}),
            sbias);
        brb_zr = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
            rb_zr);
        brb_h = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
            rb_h);
    }

    long seq_len = get_seq_len(m, seq, seq_lens);
    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);

        auto xt_w    = m.insert_instruction(ins, make_op("dot"), xt, tw);
        auto ih1_rzr = m.insert_instruction(ins, make_op("dot"), sih, trzr);
        if(bias != m.end())
        {
            xt_w    = m.insert_instruction(ins, make_op("add"), xt_w, bwb);
            ih1_rzr = m.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
        }

        auto xw_z = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w);
        auto xw_r = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w);
        auto xw_h = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w);

        auto hr_z = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr);
        auto hr_r = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr);

        auto xw_hr_z = m.insert_instruction(ins, make_op("add"), xw_z, hr_z);
        auto zt      = m.insert_instruction(ins, actv_func1, xw_hr_z);

        auto xw_hr_r = m.insert_instruction(ins, make_op("add"), xw_r, hr_r);
        auto rt      = m.insert_instruction(ins, actv_func1, xw_hr_r);

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto rt_ht1 = m.insert_instruction(ins, make_op("mul"), rt, sih);
            hr_h        = m.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
            if(bias != m.end())
            {
                hr_h = m.insert_instruction(ins, make_op("add"), hr_h, brb_h);
            }
        }
        else
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            auto ht1_rh = m.insert_instruction(ins, make_op("dot"), sih, trh);
            if(bias != m.end())
            {
                ht1_rh = m.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
            }
            hr_h = m.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
        }

        auto xw_hr_h = m.insert_instruction(ins, make_op("add"), xw_h, hr_h);
        auto ht      = m.insert_instruction(ins, actv_func2, xw_hr_h);

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = m.insert_instruction(ins, make_op("sub"), l1, zt);
        auto one_minus_zt_ht = m.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
        auto zt_ht1          = m.insert_instruction(ins, make_op("mul"), zt, sih);
        sih                  = m.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
        last_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : m.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output);
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : m.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}

std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    auto gru_op = any_cast<op::gru>(ins->get_operator());
    // before rewriting the gru operator, we need to ensure there are
    // 4 actv funcs, even if the user does not specify any. If fewer
    // than 4 are given, use the same algorithm as parse_gru to expand
    // them to 4 actv functions
    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        if(gru_op.actv_funcs.empty())
            return {make_op("sigmoid"), make_op("tanh"), make_op("sigmoid"), make_op("tanh")};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0)};
        else if(gru_op.actv_funcs.size() == 2)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1)};
        else if(gru_op.actv_funcs.size() == 3)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(2),
                    gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
    else
    {
        if(gru_op.actv_funcs.empty())
            return {make_op("sigmoid"), make_op("tanh")};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
}

// for lstm operators
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();
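    // lstm inputs follow the onnx order:
    //   X, W, R, B, sequence_lens, initial_h, initial_c, P (peephole weights);
    // missing trailing inputs arrive as "undefined" operators from the parser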

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};

    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
    op::rnn_direction dirct = lstm_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_hs_output{};
    instruction_ref last_cell_output{};
    instruction_ref hidden_state{};
    instruction_ref cell_outputs{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // input weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // hidden state weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // process bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // process initial hidden state, it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
        if(args.size() >= 7 && args[6]->name() != "undefined")
        {
            ic_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
            ic_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]);
        }
        else
        {
            ic_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
        instruction_ref pph_forward = m.end();
        instruction_ref pph_reverse = m.end();
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
            pph_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]);
        }

        auto ret_forward = lstm_cell(true,
                                     m,
                                     ins,
                                     {args[0],
                                      w_forward,
                                      r_forward,
                                      bias_forward,
                                      seq_lens,
                                      ih_forward,
                                      ic_forward,
                                      pph_forward},
                                     actv_funcs.at(0),
                                     actv_funcs.at(1),
                                     actv_funcs.at(2));

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }
        auto ret_reverse = lstm_cell(false,
                                     m,
                                     ins,
                                     {args[0],
                                      w_reverse,
                                      r_reverse,
                                      bias_reverse,
                                      seq_lens,
                                      ih_reverse,
                                      ic_reverse,
                                      pph_reverse},
                                     actv_funcs.at(3),
                                     actv_funcs.at(4),
                                     actv_funcs.at(5));
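        // lstm_cell returns {hidden states of all but the last step, last
        // hidden state, cell states of all but the last step, last cell state}
        // (see the concat logic below)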

        auto concat_hs_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        auto concat_cell_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
        last_hs_output =
            m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
        last_cell_output =
            m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);

        // the following logic is to ensure the last instruction is a concat
        if(ret_forward[0] == m.end())
        {
            cell_outputs = concat_cell_output;
        }
        else
        {
            ret_forward[1] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[1] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);

            ret_forward[3] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]);
            ret_reverse[3] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]);
            cell_outputs = m.insert_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
        }

        hidden_state = m.replace_instruction(
            ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]});
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
        if(args.size() >= 7 && args[6]->name() != "undefined")
        {
            ic = args[6];
        }
        else
        {
            ic = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
        instruction_ref pph = m.end();
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }

        if(!is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }
        auto ret = lstm_cell(is_forward,
                             m,
                             ins,
                             {args[0], w, r, bias, seq_lens, ih, ic, pph},
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

        last_hs_output   = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
        last_cell_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);

        if(ret[0] == m.end())
        {
            cell_outputs = ret[3];
            hidden_state = m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
            auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
            cell_outputs          = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);

            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            hidden_state     = m.replace_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // in case all sequences are of the same length but shorter than the
    // max sequence length, we need to pad 0's at the end of the output hidden states
    hidden_state = pad_hidden_states(m, args[0], seq_lens, hidden_state);

    // replace last hidden states with corresponding instructions
    ins = replace_last_hs_output(m, hidden_state, seq_lens, last_hs_output, dirct);

    // replace last cell outputs with corresponding instructions
    replace_last_cell_output(m, ins, seq_lens, cell_outputs, last_cell_output, dirct);
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
                                                    module& m,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
{
    // must have 8 args in the input vector
    assert(inputs.size() == 8);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
    auto ic       = inputs.at(6);
    auto pph      = inputs.at(7);
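    // implements the onnx LSTM recurrence; after Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
    // the fused gate tensor is sliced into [i | o | f | c] blocks of hidden_size:
    //   it = f(... + Pi (.) Ct-1),  ft = f(... + Pf (.) Ct-1),  ct = g(...)
    //   Ct = ft (.) Ct-1 + it (.) ct
    //   ot = f(... + Po (.) Ct),    Ht = ot (.) h(Ct)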
Shucai Xiao's avatar
Shucai Xiao committed
1045

1046
1047
    instruction_ref hidden_states = m.end();
    instruction_ref cell_outputs  = m.end();
Shucai Xiao's avatar
Shucai Xiao committed
1048
1049

    instruction_ref last_hs_output{};
Shucai Xiao's avatar
Shucai Xiao committed
1050
1051
    instruction_ref last_cell_output{};

1052
    migraphx::shape r_shape = r->get_shape();
Paul Fultz II's avatar
Paul Fultz II committed
1053
    long hs                 = r_shape.lens()[2];
1054
    auto bs                 = ih->get_shape().lens()[1];
Shucai Xiao's avatar
Shucai Xiao committed
1055
1056

    std::vector<int64_t> perm{1, 0};
1057
    // w matrix, squeeze and transpose
1058
1059
    auto sw  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tsw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
Shucai Xiao's avatar
Shucai Xiao committed
1060

1061
    // r matrix, squeeze and transpose
1062
1063
    auto sr  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto tsr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
Shucai Xiao's avatar
Shucai Xiao committed
1064

Shucai Xiao's avatar
Shucai Xiao committed
1065
    // initial hidden state
1066
    auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
Shucai Xiao's avatar
Shucai Xiao committed
1067
1068

    // initial cell state
1069
    auto sic     = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
1070
    auto ic_lens = sic->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
1071

1072
    // bias
1073
    instruction_ref wrb{};
1074
    if(bias != m.end())
1075
    {
1076

1077
1078
        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto ub_wb = m.insert_instruction(
1079
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias);
1080
        auto ub_rb = m.insert_instruction(
1081
1082
1083
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}),
            sbias);
1084
        auto ub_wrb = m.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);
1085

1086
        wrb = m.insert_instruction(
1087
            ins,
1088
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
1089
            ub_wrb);
1090
1091
    }

Shucai Xiao's avatar
Shucai Xiao committed
1092
    // peep hole
Shucai Xiao's avatar
Shucai Xiao committed
1093
1094
1095
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
1096
    if(pph != m.end())
Shucai Xiao's avatar
Shucai Xiao committed
1097
    {
1098
1099
        auto spph = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
        auto pphi = m.insert_instruction(
1100
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
1101
        pphi_brcst = m.insert_instruction(
1102
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);
1103

1104
        auto ppho = m.insert_instruction(
1105
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
1106
        ppho_brcst = m.insert_instruction(
1107
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);
1108

1109
        auto pphf = m.insert_instruction(
1110
            ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
1111
        pphf_brcst = m.insert_instruction(
1112
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
Shucai Xiao's avatar
Shucai Xiao committed
1113
    }
Shucai Xiao's avatar
Shucai Xiao committed
1114

1115
    long seq_len = get_seq_len(m, seq, seq_lens);
Shucai Xiao's avatar
Shucai Xiao committed
1116
    for(long i = 0; i < seq_len; ++i)
Shucai Xiao's avatar
Shucai Xiao committed
1117
1118
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
1119
        auto xt        = m.insert_instruction(
1120
1121
1122
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
1123
1124
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
1125

        auto xt_tsw  = m.insert_instruction(ins, make_op("dot"), xt, tsw);
        auto sih_tsr = m.insert_instruction(ins, make_op("dot"), sih, tsr);
        auto xt_sih  = m.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
        if(bias != m.end())
        {
            xt_sih = m.insert_instruction(ins, make_op("add"), xt_sih, wrb);
        }
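        // xt_sih now holds the four gate pre-activations (input and
        // recurrence contributions plus the optional bias), concatenated
        // along axis 1 in hs-wide blocks ordered [i | o | f | c];
        // the slices below pull out each gate.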

        auto it_before_actv = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih);
        auto ot_before_actv = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih);
        auto ft_before_actv = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}),
            xt_sih);
        auto ct_before_actv = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}),
            xt_sih);

        if(pph != m.end())
        {
            auto pphi_ct   = m.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
            it_before_actv = m.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);

            auto pphf_ct   = m.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
            ft_before_actv = m.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
        }
        auto it = m.insert_instruction(ins, actv_func1, it_before_actv);
        auto ft = m.insert_instruction(ins, actv_func1, ft_before_actv);
        auto ct = m.insert_instruction(ins, actv_func2, ct_before_actv);

        // equation Ct = ft (.) Ct-1 + it (.) ct
        auto ft_cell = m.insert_instruction(ins, make_op("mul"), ft, sic);
        auto it_ct   = m.insert_instruction(ins, make_op("mul"), it, ct);
        auto cellt   = m.insert_instruction(ins, make_op("add"), ft_cell, it_ct);

        if(pph != m.end())
        {
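            // unlike the input and forget gates, the output gate peephole
            // uses the just-computed cell state cellt rather than sic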
            auto ppho_cellt = m.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
            ot_before_actv  = m.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
        }
        auto ot = m.insert_instruction(ins, actv_func1, ot_before_actv);

        // Ht = ot (.) h(Ct)
        auto h_cellt = m.insert_instruction(ins, actv_func3, cellt);
        auto ht      = m.insert_instruction(ins, make_op("mul"), ot, h_cellt);

        sic = cellt;
        sih = ht;

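        // keep this step's hidden and cell state in the
        // [1, 1, batch_size, hidden_size] layout produced by the unsqueeze,
        // so successive steps can be concatenated along the sequence axis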
        last_hs_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
        last_cell_output =
            m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);

        if(i < seq_len - 1)
        {
            if(i == 0)
            {
                hidden_states = last_hs_output;
                cell_outputs  = last_cell_output;
            }
            else
            {
                auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
                auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
                hidden_states       = m.insert_instruction(
                    ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1);

                auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
                auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
                cell_outputs          = m.insert_instruction(
                    ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
            }
        }
    }

    return {hidden_states, last_hs_output, cell_outputs, last_cell_output};
}

std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    auto lstm_op = any_cast<op::lstm>(ins->get_operator());
    // before rewriting the lstm operator, we need to ensure there are
    // 6 actv funcs, even if the user did not specify any. If fewer than
    // 6 are given, use the same algorithm as parse_lstm to fill in the
    // 6 actv functions
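    // For example, a bidirectional lstm given two activation functions
    // {f, g} ends up with {f, g, g} for the forward direction and
    // {f, g, g} for the reverse direction, mirroring parse_lstm.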
    const auto& actv_funcs     = lstm_op.actv_funcs;
    std::size_t num_actv_funcs = actv_funcs.size();
    if(lstm_op.direction == op::rnn_direction::bidirectional)
    {
        switch(num_actv_funcs)
        {
        case 0:
            return {make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh"),
                    make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh")};

        case 1:
            return {actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0)};

        case 2:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1)};

        case 3:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2)};

        case 4:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(3),
                    actv_funcs.at(3)};

        case 5:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(4),
                    actv_funcs.at(4)};

        default: return actv_funcs;
        }
    }
    else
    {
        switch(num_actv_funcs)
        {
        case 0: return {make_op("sigmoid"), make_op("tanh"), make_op("tanh")};

        case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};

        case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};

        default: return actv_funcs;
        }
    }
}

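// Sequence lengths are treated as variable when they cannot be evaluated at
// compile time, or when the evaluated lengths are not all identical.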
bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens) const
{
    bool is_var_lens = false;
    if(seq_lens != m.end())
    {
        if(seq_lens->can_eval())
        {
            auto arg_lens = seq_lens->eval();
            std::vector<int64_t> vec_lens;
            arg_lens.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
            int64_t l = 0;
            if(!vec_lens.empty())
            {
                l = vec_lens[0];
            }
            if(!std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
            {
                is_var_lens = true;
            }
        }
        else
        {
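            // the lengths are only known at run time, so conservatively
            // treat them as variable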
            is_var_lens = true;
        }
    }

    return is_var_lens;
}

std::size_t
rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const
{
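    // default to the padded sequence length from the input shape; when the
    // lengths are constant and all equal, use that common value instead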
    bool is_var_lens = is_variable_seq_lens(m, seq_lens);
    auto input_shape = input->get_shape();
    auto length      = input_shape.lens()[0];
    if(!is_var_lens and seq_lens != m.end())
    {
        auto arg_len = seq_lens->eval();
        std::vector<std::size_t> vec_lens;
        arg_len.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
        length = vec_lens.empty() ? length : vec_lens[0];
    }

    return length;
}

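// Wire the hidden-state output of the rewritten rnn back into the graph.
// With variable sequence lengths the output is shifted with
// rnn_var_sl_shift_output and any rnn_last_hs_output users are redirected to
// rnn_var_sl_last_output; otherwise they are replaced directly with the
// computed last output.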
instruction_ref rewrite_rnn::replace_last_hs_output(module& m,
                                                    instruction_ref ins,
                                                    instruction_ref seq_lens,
                                                    instruction_ref last_hs_output,
                                                    op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
    instruction_ref result_ins{};
    if(variable_seq_len)
    {
        result_ins =
            m.insert_instruction(std::next(ins),
                                 make_op("rnn_var_sl_shift_output",
                                         {{"output_name", "hidden_states"}, {"direction", dirct}}),
                                 ins,
                                 seq_lens);
        m.replace_instruction(ins, result_ins);
        auto hs_outputs = find_all(result_ins->outputs(),
                                   [&](auto i) { return i->name() == "rnn_last_hs_output"; });

        for(auto& hs_out : hs_outputs)
        {
            auto inputs = hs_out->inputs();
            m.replace_instruction(hs_out,
                                  make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                  inputs.front(),
                                  seq_lens);
        }
    }
    else
    {
        auto hs_outputs =
            find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_hs_output"; });

        for(auto& hs_out : hs_outputs)
        {
            m.replace_instruction(hs_out, last_hs_output);
        }

        result_ins = ins;
    }

    return result_ins;
}

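// Same idea as replace_last_hs_output, but for rnn_last_cell_output users.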
void rewrite_rnn::replace_last_cell_output(module& m,
                                           instruction_ref ins,
                                           instruction_ref seq_lens,
                                           instruction_ref cell_outputs,
                                           instruction_ref last_cell_output,
                                           op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
    auto ins_outputs =
        find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; });

    if(variable_seq_len)
    {
        if(!ins_outputs.empty())
        {
            cell_outputs = m.insert_instruction(
                std::next(ins),
                make_op("rnn_var_sl_shift_output",
                        {{"output_name", "cell_outputs"}, {"direction", dirct}}),
                cell_outputs,
                seq_lens);
        }

        for(auto co : ins_outputs)
        {
            m.replace_instruction(co,
                                  make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                  cell_outputs,
                                  seq_lens);
        }
    }
    // replace each rnn_last_cell_output with the last_cell_output. The loop
    // handles the case of multiple rnn_last_cell_output operators
    else
    {
        for(auto co : ins_outputs)
        {
            m.replace_instruction(co, last_cell_output);
        }
    }
}

instruction_ref rewrite_rnn::pad_hidden_states(module& m,
                                               instruction_ref seq,
                                               instruction_ref seq_lens,
                                               instruction_ref hs) const
{
    auto max_seq_len = seq->get_shape().lens()[0];
    auto seq_len     = get_seq_len(m, seq, seq_lens);

    // when all sequences have the same length and it is less than
    // max_seq_len, we need to pad the hs output up to max_seq_len
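    // e.g. an input padded to max_seq_len = 5 whose sequences all have
    // length 3 gets two zero-filled timesteps appended to hs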
    auto hs_padded = hs;
    if(seq_len < max_seq_len)
    {
        auto s        = hs->get_shape();
        auto pad_lens = s.lens();
        pad_lens[0]   = static_cast<std::size_t>(max_seq_len - seq_len);
        shape pad_s{s.type(), pad_lens};
        std::vector<float> pad_data(pad_s.elements(), 0.0f);
        auto pl   = m.add_literal(pad_s, pad_data.begin(), pad_data.end());
        hs_padded = m.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
        m.replace_instruction(hs, hs_padded);
    }

    return hs_padded;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx