#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/make_op.hpp>

#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/ranges.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

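// Rewrites rnn, gru and lstm operators into primitive instructions
// (slice, dot, add, activation, concat, ...).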
void rewrite_rnn::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() == "rnn")
        {
            apply_vanilla_rnn(m, ins);
        }
        else if(ins->name() == "gru")
        {
            apply_gru(m, ins);
        }
        else if(ins->name() == "lstm")
        {
            apply_lstm(m, ins);
        }
    }
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_vanilla_rnn(module& m, instruction_ref ins) const
{
    assert(ins->name() == "rnn");
    // There could be 3 to 6 inputs, but the parse_rnn function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is that a user can pass any number
    // of arguments when writing their module.
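    // Per the ONNX RNN operator, the six arguments are expected in the order
    // X, W, R, B, sequence_lens, initial_h; inputs that were not supplied
    // appear here as "undefined" instructions.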
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0);

    auto actv_funcs         = vanilla_rnn_actv_funcs(ins);
    auto rnn_op             = any_cast<op::rnn>(ins->get_operator());
    op::rnn_direction dirct = rnn_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_output{};
    if(dirct == op::rnn_direction::bidirectional)
    {
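        // In the bidirectional case, dimension 0 of W, R, B and initial_h has
        // size 2: index 0 holds the forward direction, index 1 the reverse.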
        // input weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // hidden state weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // process bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // process initial hidden state, it could be the 6th argument
        // or the 5th one (if the sequence len argument is ignored)
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward =
            vanilla_rnn_cell(true,
                             m,
                             ins,
                             {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                             actv_funcs.at(0));

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret_reverse =
            vanilla_rnn_cell(false,
                             m,
                             ins,
                             {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                             actv_funcs.at(1));

        auto concat_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);

        // The following logic ensures that the last instruction rewritten from
        // the rnn operator is a concat instruction. When ret_forward[0] is
        // m.end(), the sequence length is 1.
        if(ret_forward[0] == m.end())
        {
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        }
        else
        {
            ret_forward[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // input weight matrix
        auto w = args[1];

        // hidden state weight matrix
        auto r = args[2];

        // process bias and initial hidden state
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // process initial hidden state
        instruction_ref ih;
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ih_shape, data});
        }

        if(!is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret = vanilla_rnn_cell(
            is_forward, m, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);

        // The following logic ensures the last instruction is a concat
        // instruction. When ret[0] is m.end(), the sequence length is 1.
        if(ret[0] == m.end())
        {
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // In case all sequences have the same length but are shorter than the
    // max sequence length, pad the output hidden states with 0's at the end.
    ins = pad_hidden_states(m, args[0], seq_lens, ins);
    replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}

std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
                                                           module& m,
                                                           instruction_ref ins,
                                                           std::vector<instruction_ref> inputs,
                                                           operation& actv_func) const
{
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
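    // Per time step this computes, following the ONNX RNN definition,
    //   Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
    // where f is the activation function passed in as actv_func.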

    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
    auto sw      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tran_sw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);

    // squeeze and transpose r
    auto sr      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto tran_sr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);

    // initial hidden state
    auto sih      = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
    auto sih_lens = sih->get_shape().lens();

    // bias
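    // ONNX packs the RNN bias as [Wbi, Rbi]; the two halves are added together
    // and broadcast to the [batch_size, hidden_size] shape of the hidden state.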
    instruction_ref bb{};
    if(bias != m.end())
    {
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias);
        auto rb = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
        auto wrb = m.insert_instruction(ins, make_op("add"), wb, rb);
        bb       = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
    }

    instruction_ref hidden_out = m.end();
    instruction_ref last_out{};
    last_out     = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
    long seq_len = get_seq_len(m, seq, seq_lens);
    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
        auto xt_wi   = m.insert_instruction(ins, make_op("dot"), xt, tran_sw);
        auto ht_ri   = m.insert_instruction(ins, make_op("dot"), sih, tran_sr);
        if(bias != m.end())
        {
            xt_wi = m.insert_instruction(ins, make_op("add"), xt_wi, bb);
        }
        auto xt_ht = m.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);

        // apply activation function
        auto ht = m.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;

        // add back the dimensions for sequence length (axis 0) and
        // num_directions (axis 1)
        last_out = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);

        // concatenation with the final last_out is performed in the apply()
        // function so that the last instruction is a concat, which then
        // becomes the output
        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_out = (seq_index == 0)
                                 ? last_out
                                 : m.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out);
            }
            else
            {
                hidden_out = (seq_index == seq_len - 1)
                                 ? last_out
                                 : m.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out);
            }
        }
    }

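    // hidden_out holds the hidden states of every step except the final one;
    // the caller concatenates last_out onto it so that the rewritten program
    // ends in a concat instruction.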
    return {hidden_out, last_out};
}

std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
    // There could be 3 to 6 inputs, but the parse_rnn function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is that a user can pass any number
    // of arguments when writing their program.
    if(rnn_op.direction == op::rnn_direction::bidirectional)
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {make_op("tanh"), make_op("tanh")};
        }
        else if(rnn_op.actv_funcs.size() == 1)
        {
            return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
    else
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {make_op("tanh")};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_gru(module& m, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);
    // There could be 3 to 6 inputs, but the parse_gru function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is that a user can pass any number
    // of arguments when writing their program.
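    // Per the ONNX GRU operator, the six arguments are expected in the order
    // X, W, R, B, sequence_lens, initial_h, with missing inputs appearing as
    // "undefined" instructions.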
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0.0);

    auto gru_op             = any_cast<op::gru>(ins->get_operator());
    op::rnn_direction dirct = gru_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_output{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // w weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // r weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // initial hidden state
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = m.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward =
            gru_cell(true,
                     m,
                     ins,
                     {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                     gru_op.linear_before_reset,
                     actv_funcs.at(0),
                     actv_funcs.at(1));

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret_reverse =
            gru_cell(false,
                     m,
                     ins,
                     {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                     gru_op.linear_before_reset,
                     actv_funcs.at(2),
                     actv_funcs.at(3));

        auto concat_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);

        // The following logic ensures that the last instruction rewritten
        // from the gru operator is a concat
        if(ret_forward[0] == m.end())
        {
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        }
        else
        {
            ret_forward[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            m.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // weight matrix
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ih_shape, data});
        }

        if(!is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }

        auto ret = gru_cell(is_forward,
                            m,
                            ins,
                            {args[0], w, r, bias, seq_lens, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);

        if(ret[0] == m.end())
        {
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // In case all sequences have the same length but are shorter than the
    // max sequence length, pad the output hidden states with 0's at the end.
    ins = pad_hidden_states(m, args[0], seq_lens, ins);
    replace_last_hs_output(m, ins, seq_lens, last_output, dirct);
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
                                                   module& m,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
{
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);

    instruction_ref hidden_states = m.end();
    instruction_ref last_output{};
    migraphx::shape seq_shape = seq->get_shape();
    migraphx::shape r_shape   = r->get_shape();
    long hs                   = r_shape.lens()[2];

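    // literal of ones with shape [batch_size, hidden_size]; used to form
    // (1 - zt) in the hidden state update below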
    migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
    std::vector<float> data(ss.elements(), 1.0f);
    auto l1 = m.add_literal(migraphx::literal{ss, data});

    // squeeze the w matrix to 2 dims and transpose it
    std::vector<int64_t> perm{1, 0};
    auto sw = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);

    // slice r into two parts: zr and h
    auto sr  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto rzr = m.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
    auto trzr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);

    auto rh = m.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
    auto trh = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);

    // initial states
    auto sih  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
    size_t bs = ih->get_shape().lens()[1];

    // bias
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
    if(bias != m.end())
    {
        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
        bwb = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
            wb);

        auto rb_zr = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}),
            sbias);
        auto rb_h = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}),
            sbias);
        brb_zr = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
            rb_zr);
        brb_h = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
            rb_h);
    }

    long seq_len = get_seq_len(m, seq, seq_lens);
    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);

        auto xt_w    = m.insert_instruction(ins, make_op("dot"), xt, tw);
        auto ih1_rzr = m.insert_instruction(ins, make_op("dot"), sih, trzr);
        if(bias != m.end())
        {
            xt_w    = m.insert_instruction(ins, make_op("add"), xt_w, bwb);
            ih1_rzr = m.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
        }

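        // Per the ONNX GRU spec, W and R stack the gates as [z, r, h] along
        // the hidden axis, so slice the products into per-gate chunks.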
        auto xw_z = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w);
        auto xw_r = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w);
        auto xw_h = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w);

        auto hr_z = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr);
        auto hr_r = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr);

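        // update and reset gates, following the ONNX GRU equations:
        //   zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
        //   rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)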
        auto xw_hr_z = m.insert_instruction(ins, make_op("add"), xw_z, hr_z);
        auto zt      = m.insert_instruction(ins, actv_func1, xw_hr_z);

        auto xw_hr_r = m.insert_instruction(ins, make_op("add"), xw_r, hr_r);
        auto rt      = m.insert_instruction(ins, actv_func1, xw_hr_r);

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto rt_ht1 = m.insert_instruction(ins, make_op("mul"), rt, sih);
            hr_h        = m.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
            if(bias != m.end())
            {
                hr_h = m.insert_instruction(ins, make_op("add"), hr_h, brb_h);
            }
        }
        else
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            auto ht1_rh = m.insert_instruction(ins, make_op("dot"), sih, trh);
            if(bias != m.end())
            {
                ht1_rh = m.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
            }
            hr_h = m.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
        }

        auto xw_hr_h = m.insert_instruction(ins, make_op("add"), xw_h, hr_h);
        auto ht      = m.insert_instruction(ins, actv_func2, xw_hr_h);

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = m.insert_instruction(ins, make_op("sub"), l1, zt);
        auto one_minus_zt_ht = m.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
        auto zt_ht1          = m.insert_instruction(ins, make_op("mul"), zt, sih);
        sih                  = m.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
        last_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : m.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output);
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : m.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states);
            }
        }
    }
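    // hidden_states omits the final step; the caller appends last_output so
    // the rewritten program ends in a concat.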

    return {hidden_states, last_output};
}

std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    auto gru_op = any_cast<op::gru>(ins->get_operator());
    // Before rewriting the gru operator, we need to ensure there are 4
    // activation functions, even if the user did not specify any. If fewer
    // than 4 are given, use the same algorithm as parse_gru to expand them
    // to 4 activation functions.
    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        if(gru_op.actv_funcs.empty())
            return {make_op("sigmoid"), make_op("tanh"), make_op("sigmoid"), make_op("tanh")};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0)};
        else if(gru_op.actv_funcs.size() == 2)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1)};
        else if(gru_op.actv_funcs.size() == 3)
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(2),
                    gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
    else
    {
        if(gru_op.actv_funcs.empty())
            return {make_op("sigmoid"), make_op("tanh")};
        else if(gru_op.actv_funcs.size() == 1)
            return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
        else
            return gru_op.actv_funcs;
    }
}

// for lstm operators
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void rewrite_rnn::apply_lstm(module& m, instruction_ref ins) const
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();
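    // Per the ONNX LSTM operator, the (up to) 8 arguments are, in order:
    // X, W, R, B, sequence_lens, initial_h, initial_c and the peephole
    // weights P; inputs that were not supplied appear as "undefined".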

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};

    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
    op::rnn_direction dirct = lstm_op.direction;

    // process sequence length
    instruction_ref seq_lens = m.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);

    instruction_ref last_hs_output{};
    instruction_ref last_cell_output{};
    instruction_ref hidden_state{};
    instruction_ref cell_outputs{};
    if(dirct == op::rnn_direction::bidirectional)
    {
        // input weight matrix
        auto w_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);

        // hidden state weight matrix
        auto r_forward = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);

        // process bias
        instruction_ref bias_forward = m.end();
        instruction_ref bias_reverse = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
        }

        // process initial hidden state, it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
        }
        else
        {
            ih_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
        if(args.size() >= 7 && args[6]->name() != "undefined")
        {
            ic_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
            ic_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]);
        }
        else
        {
            ic_forward = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
        instruction_ref pph_forward = m.end();
        instruction_ref pph_reverse = m.end();
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph_forward = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
            pph_reverse = m.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]);
        }

        auto ret_forward = lstm_cell(true,
                                     m,
                                     ins,
                                     {args[0],
                                      w_forward,
                                      r_forward,
                                      bias_forward,
                                      seq_lens,
                                      ih_forward,
                                      ic_forward,
                                      pph_forward},
                                     actv_funcs.at(0),
                                     actv_funcs.at(1),
                                     actv_funcs.at(2));

        if(variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }
        auto ret_reverse = lstm_cell(false,
                                     m,
                                     ins,
                                     {args[0],
                                      w_reverse,
                                      r_reverse,
                                      bias_reverse,
                                      seq_lens,
                                      ih_reverse,
                                      ic_reverse,
                                      pph_reverse},
                                     actv_funcs.at(3),
                                     actv_funcs.at(4),
                                     actv_funcs.at(5));

        auto concat_hs_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        auto concat_cell_output = m.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
        last_hs_output =
            m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
        last_cell_output =
            m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);

        // the following logic is to ensure the last instruction is a concat
        if(ret_forward[0] == m.end())
        {
            cell_outputs = concat_cell_output;
        }
        else
        {
            ret_forward[1] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[1] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);

            ret_forward[3] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]);
            ret_reverse[3] = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]);
            cell_outputs = m.insert_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
        }

        hidden_state = m.replace_instruction(
            ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]});
    }
    else
    {
        bool is_forward = (dirct == op::rnn_direction::forward);
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = m.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
        if(args.size() >= 7 && args[6]->name() != "undefined")
        {
            ic = args[6];
        }
        else
        {
            ic = m.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
        instruction_ref pph = m.end();
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }

        if(!is_forward and variable_seq_len)
        {
            args[0] =
                m.insert_instruction(ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
        }
        auto ret = lstm_cell(is_forward,
                             m,
                             ins,
                             {args[0], w, r, bias, seq_lens, ih, ic, pph},
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

        last_hs_output   = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
        last_cell_output = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);

        if(ret[0] == m.end())
        {
            cell_outputs = ret[3];
            hidden_state = m.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
        }
        else
        {
            auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
            auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
            cell_outputs          = m.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);

            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            hidden_state     = m.replace_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
        }
    }

    // In case all sequences have the same length but are shorter than the
    // max sequence length, pad the output hidden states with 0's at the end.
    hidden_state = pad_hidden_states(m, args[0], seq_lens, hidden_state);

    // replace last hidden states with corresponding instructions
    ins = replace_last_hs_output(m, hidden_state, seq_lens, last_hs_output, dirct);

    // replace last cell outputs with corresponding instructions
    replace_last_cell_output(m, ins, seq_lens, cell_outputs, last_cell_output, dirct);
}

// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
                                                    module& m,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
{
    // must have 8 args in the input vector
    assert(inputs.size() == 8);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
    auto ic       = inputs.at(6);
    auto pph      = inputs.at(7);

    instruction_ref hidden_states = m.end();
    instruction_ref cell_outputs  = m.end();

    instruction_ref last_hs_output{};
    instruction_ref last_cell_output{};

    migraphx::shape r_shape = r->get_shape();
    long hs                 = r_shape.lens()[2];
    auto bs                 = ih->get_shape().lens()[1];

    std::vector<int64_t> perm{1, 0};
    // w matrix, squeeze and transpose
    auto sw  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
    auto tsw = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);

    // r matrix, squeeze and transpose
    auto sr  = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto tsr = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);

    // initial hidden state
    auto sih = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);

    // initial cell state
    auto sic     = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
    auto ic_lens = sic->get_shape().lens();

    // bias
    instruction_ref wrb{};
    if(bias != m.end())
    {

        auto sbias = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto ub_wb = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias);
        auto ub_rb = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}),
            sbias);
        auto ub_wrb = m.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);

        wrb = m.insert_instruction(
            ins,
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
            ub_wrb);
    }

    // peephole weights
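    // P is laid out as [Pi, Po, Pf] along the hidden axis; each slice is
    // broadcast to the shape of the cell state before use.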
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
    if(pph != m.end())
    {
        auto spph = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
        auto pphi = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
        pphi_brcst = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);

        auto ppho = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
        ppho_brcst = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);

        auto pphf = m.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
        pphf_brcst = m.insert_instruction(
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
    }

    long seq_len = get_seq_len(m, seq, seq_lens);
    for(long i = 0; i < seq_len; ++i)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt        = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = m.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);

        auto xt_tsw  = m.insert_instruction(ins, make_op("dot"), xt, tsw);
        auto sih_tsr = m.insert_instruction(ins, make_op("dot"), sih, tsr);
        auto xt_sih  = m.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
        if(bias != m.end())
        {
            xt_sih = m.insert_instruction(ins, make_op("add"), xt_sih, wrb);
        }

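        // xt_sih now holds Xt*(W^T) + Ht-1*(R^T) (+ Wb + Rb); per the ONNX
        // LSTM spec the gates are stacked along the hidden axis in
        // [input, output, forget, cell] order, so slice out the four
        // pre-activation chunks.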
        auto it_before_actv = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih);
        auto ot_before_actv = m.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih);
        auto ft_before_actv = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}),
            xt_sih);
        auto ct_before_actv = m.insert_instruction(
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}),
            xt_sih);

        if(pph != m.end())
        {
            auto pphi_ct   = m.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
            it_before_actv = m.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);

            auto pphf_ct   = m.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
            ft_before_actv = m.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
Shucai Xiao's avatar
Shucai Xiao committed
1131
        }
1132
1133
1134
        auto it = m.insert_instruction(ins, actv_func1, it_before_actv);
        auto ft = m.insert_instruction(ins, actv_func1, ft_before_actv);
        auto ct = m.insert_instruction(ins, actv_func2, ct_before_actv);
1135
1136

        // equation Ct = ft (.) Ct-1 + it (.) ct
1137
1138
1139
        auto ft_cell = m.insert_instruction(ins, make_op("mul"), ft, sic);
        auto it_ct   = m.insert_instruction(ins, make_op("mul"), it, ct);
        auto cellt   = m.insert_instruction(ins, make_op("add"), ft_cell, it_ct);
1140

1141
        if(pph != m.end())
Shucai Xiao's avatar
Shucai Xiao committed
1142
        {
1143
1144
            auto ppho_cellt = m.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
            ot_before_actv  = m.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
Shucai Xiao's avatar
Shucai Xiao committed
1145
        }
1146
        auto ot = m.insert_instruction(ins, actv_func1, ot_before_actv);
1147
1148

        // Ht = ot (.) h(Ct)
1149
1150
        auto h_cellt = m.insert_instruction(ins, actv_func3, cellt);
        auto ht      = m.insert_instruction(ins, make_op("mul"), ot, h_cellt);
1151
1152
1153
1154

        sic = cellt;
        sih = ht;

1155
        last_hs_output = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
1156
        last_cell_output =
1157
            m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);
1158

Shucai Xiao's avatar
Shucai Xiao committed
1159
        if(i < seq_len - 1)
1160
        {
Shucai Xiao's avatar
Shucai Xiao committed
1161
            if(i == 0)
1162
            {
Shucai Xiao's avatar
Shucai Xiao committed
1163
1164
                hidden_states = last_hs_output;
                cell_outputs  = last_cell_output;
1165
1166
1167
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
1168
1169
                auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
                auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
1170
                hidden_states       = m.insert_instruction(
1171
                    ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
1172
1173
1174

                auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
                auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
1175
                cell_outputs          = m.insert_instruction(
1176
                    ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
1177
1178
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1179
    }
1180

Shucai Xiao's avatar
Shucai Xiao committed
1181
    return {hidden_states, last_hs_output, cell_outputs, last_cell_output};
}

std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    auto lstm_op = any_cast<op::lstm>(ins->get_operator());
    // Before rewriting the lstm operator, we need to ensure the expected
    // number of actv funcs is present (6 for bidirectional, 3 otherwise),
    // even if the user does not specify any. If fewer are given, use the
    // same algorithm as parse_lstm to expand them.
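    // For example, a single user-provided activation is reused for all six
    // slots of a bidirectional LSTM (case 1 below), and two activations are
    // repeated per direction (case 2).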
    const auto& actv_funcs     = lstm_op.actv_funcs;
    std::size_t num_actv_funcs = actv_funcs.size();
    if(lstm_op.direction == op::rnn_direction::bidirectional)
    {
        switch(num_actv_funcs)
        {
        case 0:
            return {make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh"),
                    make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh")};

        case 1:
            return {actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0)};

        case 2:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1)};

        case 3:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2)};

        case 4:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(3),
                    actv_funcs.at(3)};

        case 5:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(4),
                    actv_funcs.at(4)};

        default: return actv_funcs;
        }
    }
    else
    {
        switch(num_actv_funcs)
        {
        case 0: return {make_op("sigmoid"), make_op("tanh"), make_op("tanh")};

        case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};

        case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};

        default: return actv_funcs;
        }
    }
}

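// Returns true when the sequence-length input must be treated as variable:
// either it cannot be evaluated at compile time, or its evaluated per-batch
// lengths are not all equal.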
bool rewrite_rnn::is_variable_seq_lens(const module& m, instruction_ref seq_lens) const
{
    bool is_var_lens = false;
    if(seq_lens != m.end())
    {
        if(seq_lens->can_eval())
        {
            auto arg_lens = seq_lens->eval();
            std::vector<int64_t> vec_lens;
            arg_lens.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
            int64_t l = 0;
            if(!vec_lens.empty())
            {
                l = vec_lens[0];
            }
            if(!std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
            {
                is_var_lens = true;
            }
        }
        else
        {
            is_var_lens = true;
        }
    }

    return is_var_lens;
}

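// Compute the effective number of time steps: the first dimension of the
// input by default, or the uniform constant value in seq_lens when one is
// provided and the lengths are not variable.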
std::size_t
rewrite_rnn::get_seq_len(const module& m, instruction_ref input, instruction_ref seq_lens) const
{
    bool is_var_lens = is_variable_seq_lens(m, seq_lens);
    auto input_shape = input->get_shape();
    auto length      = input_shape.lens()[0];
    if(!is_var_lens and seq_lens != m.end())
    {
        auto arg_len = seq_lens->eval();
        std::vector<std::size_t> vec_lens;
        arg_len.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
        length = vec_lens.empty() ? length : vec_lens[0];
    }

    return length;
}

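// Wire up the hidden-state outputs of the rewritten RNN. With variable
// sequence lengths, insert rnn_var_sl_shift_output after ins and rewrite any
// rnn_last_hs_output user to rnn_var_sl_last_output; otherwise replace those
// users directly with the computed last_hs_output.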
instruction_ref rewrite_rnn::replace_last_hs_output(module& m,
                                                    instruction_ref ins,
                                                    instruction_ref seq_lens,
                                                    instruction_ref last_hs_output,
                                                    op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
    instruction_ref result_ins{};
    if(variable_seq_len)
    {
        result_ins =
            m.insert_instruction(std::next(ins),
                                 make_op("rnn_var_sl_shift_output",
                                         {{"output_name", "hidden_states"}, {"direction", dirct}}),
                                 ins,
                                 seq_lens);
        m.replace_instruction(ins, result_ins);
        auto hs_outputs = find_all(result_ins->outputs(),
                                   [&](auto i) { return i->name() == "rnn_last_hs_output"; });

        for(auto& hs_out : hs_outputs)
        {
            auto inputs = hs_out->inputs();
            m.replace_instruction(hs_out,
                                  make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                  inputs.front(),
                                  seq_lens);
        }
    }
    else
    {
        auto hs_outputs =
            find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_hs_output"; });

        for(auto& hs_out : hs_outputs)
        {
            m.replace_instruction(hs_out, last_hs_output);
        }

        result_ins = ins;
    }

    return result_ins;
}

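// Same wiring for the cell-state outputs: with variable sequence lengths the
// accumulated cell_outputs are shifted first and every rnn_last_cell_output
// user becomes rnn_var_sl_last_output; otherwise those users are replaced
// with the computed last_cell_output.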
void rewrite_rnn::replace_last_cell_output(module& m,
                                           instruction_ref ins,
                                           instruction_ref seq_lens,
                                           instruction_ref cell_outputs,
                                           instruction_ref last_cell_output,
                                           op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(m, seq_lens);
    auto ins_outputs =
        find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; });

    if(variable_seq_len)
    {
        if(!ins_outputs.empty())
        {
            cell_outputs = m.insert_instruction(
                std::next(ins),
                make_op("rnn_var_sl_shift_output",
                        {{"output_name", "cell_outputs"}, {"direction", dirct}}),
                cell_outputs,
                seq_lens);
        }

        for(auto co : ins_outputs)
        {
            m.replace_instruction(co,
                                  make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                  cell_outputs,
                                  seq_lens);
        }
    }
    // Replace each rnn_last_cell_output with the computed last_cell_output.
    // The loop handles the case of multiple rnn_last_cell_output operators.
    else
    {
        for(auto co : ins_outputs)
        {
            m.replace_instruction(co, last_cell_output);
        }
    }
}

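// When every sequence has the same constant length but it is shorter than the
// input's max sequence length, pad the hidden-state output with zeros along
// the time axis so it still covers max_seq_len steps.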
instruction_ref rewrite_rnn::pad_hidden_states(module& m,
                                               instruction_ref seq,
                                               instruction_ref seq_lens,
                                               instruction_ref hs) const
{
    auto max_seq_len = seq->get_shape().lens()[0];
    auto seq_len     = get_seq_len(m, seq, seq_lens);

    // When all sequences have the same length and it is less than
    // max_seq_len, pad the hidden-state output up to max_seq_len.
    auto hs_padded = hs;
    if(seq_len < max_seq_len)
    {
        auto s        = hs->get_shape();
        auto pad_lens = s.lens();
        pad_lens[0]   = static_cast<std::size_t>(max_seq_len - seq_len);
        shape pad_s{s.type(), pad_lens};
        std::vector<float> pad_data(pad_s.elements(), 0.0f);
        auto pl   = m.add_literal(pad_s, pad_data.begin(), pad_data.end());
        hs_padded = m.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, pl);
        m.replace_instruction(hs, hs_padded);
    }

    return hs_padded;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx