"vscode:/vscode.git/clone" did not exist on "46b50620f8899137625d8127b92afb8acbf9aca5"
rewrite_rnn.cpp 55.7 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
17
#include <migraphx/op/contiguous.hpp>
18
19
20
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
21
22
#include <migraphx/make_op.hpp>

Shucai Xiao's avatar
Shucai Xiao committed
23
24
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
25
#include <migraphx/ranges.hpp>
26
#include <migraphx/op/common.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
27
28
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
29
30
31
32

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

33
// Visit every instruction of the module and hand each recurrent operator
// ("rnn", "gru" or "lstm") to the matching apply_* member, which rewrites it
// in place. Instructions with any other name are left untouched.
void rewrite_rnn::apply(module& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        // Read the name once up front; at most one rewrite runs per
        // instruction, exactly as with a chained name comparison.
        const auto op_name = ins->name();
        if(op_name == "rnn")
        {
            apply_vanilla_rnn(prog, ins);
        }
        else if(op_name == "gru")
        {
            apply_gru(prog, ins);
        }
        else if(op_name == "lstm")
        {
            apply_lstm(prog, ins);
        }
    }
}

52
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
53
void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
54
55
56
57
{
    assert(ins->name() == "rnn");
    // could be 3 to 6 inputs, but the parse_rnn function will
    // append undefined operators to make 6 arguments when parsing
Shucai Xiao's avatar
Shucai Xiao committed
58
    // an onnx file. Another case is user can have num of arguments
59
    // when writing their module.
60
61
62
63
64
65
66
67
68
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0);

Shucai Xiao's avatar
Shucai Xiao committed
69
70
    auto actv_funcs         = vanilla_rnn_actv_funcs(ins);
    auto rnn_op             = any_cast<op::rnn>(ins->get_operator());
71
72
73
74
75
76
77
78
79
80
81
    op::rnn_direction dirct = rnn_op.direction;

    // process sequence length
    instruction_ref seq_lens = prog.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

82
    instruction_ref last_output{};
83
    if(dirct == op::rnn_direction::bidirectional)
84
85
    {
        // input weight matrix
86
87
88
89
        auto w_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
90
91

        // hidden state weight matrix
92
93
94
95
        auto r_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
96
97
98
99

        // process bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
100
        if(args.size() >= 4 && args[3]->name() != "undefined")
101
        {
102
103
104
105
            bias_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
106
107
108
109
110
111
        }

        // process intial hidden state, it could be the 6th argument
        // or the 5th one (if the sequence len argument is ignored)
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
112
        if(args.size() == 6 && args[5]->name() != "undefined")
113
        {
114
115
116
117
            ih_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
118
119
120
121
122
123
124
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

125
126
127
128
129
130
131
132
133
        auto ret_forward =
            vanilla_rnn_cell(true,
                             prog,
                             ins,
                             {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                             actv_funcs.at(0));

        if(variable_seq_len)
        {
134
135
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
136
137
138
139
140
141
142
143
        }

        auto ret_reverse =
            vanilla_rnn_cell(false,
                             prog,
                             ins,
                             {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                             actv_funcs.at(1));
144

145
146
147
148
        auto concat_output = prog.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output =
            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
149
150
151
152
153
154

        // The following logic is to ensure the last instruction rewritten from
        // rnn operator is a concat instruction
        // sequence len is 1
        if(ret_forward[0] == prog.end())
        {
155
156
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
157
158
159
        }
        else
        {
160
161
162
163
164
165
            ret_forward[0] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
166
167
168
169
        }
    }
    else
    {
170
        bool is_forward = (dirct == op::rnn_direction::forward);
171
172
173
174
175
176
177
178
        // input weight matrix
        auto w = args[1];

        // hidden state weight matrix
        auto r = args[2];

        // process bias and initial hidden state
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
179
        if(args.size() >= 4 && args[3]->name() != "undefined")
180
181
182
183
184
185
        {
            bias = args[3];
        }

        // process intial hidden state
        instruction_ref ih;
Shucai Xiao's avatar
Shucai Xiao committed
186
        if(args.size() == 6 && args[5]->name() != "undefined")
187
188
189
190
191
192
193
194
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

195
196
        if(!is_forward and variable_seq_len)
        {
197
198
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
199
200
201
202
        }

        auto ret = vanilla_rnn_cell(
            is_forward, prog, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
203
        last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
204
205
206
207
208
209

        // following logic is to ensure the last instruction is a
        // concat instruction
        // sequence len is 1
        if(ret[0] == prog.end())
        {
210
            prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
211
212
213
214
215
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
216
217
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
218
219
220
        }
    }

221
222
223
224
    // in case of all sequences are of the same lengths and shorter than the
    // max sequence length, need to pad 0's at the end for output hidden states
    ins = pad_hidden_states(prog, args[0], seq_lens, ins);
    replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
Shucai Xiao's avatar
Shucai Xiao committed
225
226
}

Shucai Xiao's avatar
Shucai Xiao committed
227
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
228
                                                           module& prog,
Shucai Xiao's avatar
Shucai Xiao committed
229
                                                           instruction_ref ins,
230
                                                           std::vector<instruction_ref> inputs,
Shucai Xiao's avatar
Shucai Xiao committed
231
                                                           operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
232
{
233
234
235
236
237
238
239
240
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);

Shucai Xiao's avatar
Shucai Xiao committed
241
242
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
243
    auto sw      = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
244
    auto tran_sw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
Shucai Xiao's avatar
Shucai Xiao committed
245
246

    // squeeze and transpose r
247
    auto sr      = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
248
    auto tran_sr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
Shucai Xiao's avatar
Shucai Xiao committed
249
250

    // initial hidden state
251
    auto sih      = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
252
    auto sih_lens = sih->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
253
254

    // bias
255
    instruction_ref bb{};
Shucai Xiao's avatar
Shucai Xiao committed
256
    if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
257
    {
258
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
259
260
261
262
263
264
265
        auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), sbias);
        auto rb = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), sbias);
        auto wrb = prog.insert_instruction(ins, make_op("add"), wb, rb);
        bb       = prog.insert_instruction(
266
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", sih_lens}}), wrb);
Shucai Xiao's avatar
Shucai Xiao committed
267
268
    }

Shucai Xiao's avatar
Shucai Xiao committed
269
270
    instruction_ref hidden_out = prog.end();
    instruction_ref last_out{};
271
    last_out     = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
Paul Fultz II's avatar
Paul Fultz II committed
272
    long seq_len = get_seq_len(prog, seq, seq_lens);
273
    for(long i = 0; i < seq_len; i++)
Shucai Xiao's avatar
Shucai Xiao committed
274
    {
Shucai Xiao's avatar
Shucai Xiao committed
275
        long seq_index = is_forward ? i : (seq_len - 1 - i);
276
277
278
279
280
281
282
283
        auto xt        = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);
        auto xt_wi   = prog.insert_instruction(ins, make_op("dot"), xt, tran_sw);
        auto ht_ri   = prog.insert_instruction(ins, make_op("dot"), sih, tran_sr);
Shucai Xiao's avatar
Shucai Xiao committed
284
        if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
285
        {
286
            xt_wi = prog.insert_instruction(ins, make_op("add"), xt_wi, bb);
Shucai Xiao's avatar
Shucai Xiao committed
287
        }
288
        auto xt_ht = prog.insert_instruction(ins, make_op("add"), xt_wi, ht_ri);
Shucai Xiao's avatar
Shucai Xiao committed
289
290

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
291
292
        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;
Shucai Xiao's avatar
Shucai Xiao committed
293

Shucai Xiao's avatar
Shucai Xiao committed
294
295
        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
296
        last_out = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
Shucai Xiao's avatar
Shucai Xiao committed
297

Shucai Xiao's avatar
Shucai Xiao committed
298
299
300
        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
Shucai Xiao's avatar
Shucai Xiao committed
301
        if(i < seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
302
        {
Shucai Xiao's avatar
Shucai Xiao committed
303
304
            if(is_forward)
            {
305
306
307
308
                hidden_out = (seq_index == 0)
                                 ? last_out
                                 : prog.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
309
310
311
            }
            else
            {
312
313
314
315
                hidden_out = (seq_index == seq_len - 1)
                                 ? last_out
                                 : prog.insert_instruction(
                                       ins, make_op("concat", {{"axis", 0}}), last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
316
            }
Shucai Xiao's avatar
Shucai Xiao committed
317
318
319
        }
    }

320
    return {hidden_out, last_out};
Shucai Xiao's avatar
Shucai Xiao committed
321
322
}

Shucai Xiao's avatar
Shucai Xiao committed
323
// Resolve the activation functions used to rewrite an rnn instruction.
//
// A bidirectional rnn gets two functions (forward first, reverse second) and
// any other direction gets one. When the operator lists no activation, tanh
// is used. A bidirectional operator that lists a single function uses it for
// both directions. In every other case the operator's list is returned as is.
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    auto rnn_op = any_cast<op::rnn>(ins->get_operator());

    const bool bidirectional = (rnn_op.direction == op::rnn_direction::bidirectional);
    const auto& listed       = rnn_op.actv_funcs;

    if(listed.empty())
    {
        // default is tanh, one per direction
        return std::vector<operation>(bidirectional ? 2 : 1, make_op("tanh"));
    }

    if(bidirectional and listed.size() == 1)
    {
        // share the single listed function between both directions
        return {listed.front(), listed.front()};
    }

    return listed;
}

360
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
361
void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
362
363
364
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);
Shucai Xiao's avatar
Shucai Xiao committed
365
366
367
368
    // could be 3 to 6 inputs, but the parse_gru function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have num of arguments
    // when writing their program.
369
370
371
372
373
374
375
376
377
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0.0);

Shucai Xiao's avatar
Shucai Xiao committed
378
    auto gru_op             = any_cast<op::gru>(ins->get_operator());
379
380
381
382
383
384
385
386
387
388
389
    op::rnn_direction dirct = gru_op.direction;

    // process sequence length
    instruction_ref seq_lens = prog.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

390
    instruction_ref last_output{};
391
    if(dirct == op::rnn_direction::bidirectional)
392
393
    {
        // w weight matrix
394
395
396
397
        auto w_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
398
399

        // r weight matrix
400
401
402
403
        auto r_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
404
405
406
407

        // bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
408
        if(args.size() >= 4 && args[3]->name() != "undefined")
409
        {
410
411
412
413
            bias_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
414
415
416
417
418
        }

        // intial hidden state
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
419
        if(args.size() == 6 && args[5]->name() != "undefined")
420
        {
421
422
423
424
            ih_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
425
426
427
428
429
430
431
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

432
433
434
435
436
437
438
439
440
441
442
        auto ret_forward =
            gru_cell(true,
                     prog,
                     ins,
                     {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                     gru_op.linear_before_reset,
                     actv_funcs.at(0),
                     actv_funcs.at(1));

        if(variable_seq_len)
        {
443
444
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
445
        }
Shucai Xiao's avatar
Shucai Xiao committed
446

447
448
449
450
451
452
453
454
        auto ret_reverse =
            gru_cell(false,
                     prog,
                     ins,
                     {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                     gru_op.linear_before_reset,
                     actv_funcs.at(2),
                     actv_funcs.at(3));
455

456
457
458
459
        auto concat_output = prog.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        last_output =
            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_output);
460
461
462
463
464

        // The following logic is to ensure the last instruction rewritten
        // from gru operator is a concat
        if(ret_forward[0] == prog.end())
        {
465
466
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
467
468
469
        }
        else
        {
470
471
472
473
474
475
            ret_forward[0] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[0] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 1}}), {ret_forward[0], ret_reverse[0]});
476
477
478
479
        }
    }
    else
    {
480
        bool is_forward = (dirct == op::rnn_direction::forward);
481
482
483
484
485
486
        // weight matrix
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
487
        if(args.size() >= 4 && args[3]->name() != "undefined")
488
489
490
491
492
493
        {
            bias = args[3];
        }

        // intial hidden state
        instruction_ref ih{};
Shucai Xiao's avatar
Shucai Xiao committed
494
        if(args.size() == 6 && args[5]->name() != "undefined")
495
496
497
498
499
500
501
502
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

503
504
        if(!is_forward and variable_seq_len)
        {
505
506
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
507
508
        }

509
510
511
        auto ret = gru_cell(is_forward,
                            prog,
                            ins,
512
                            {args[0], w, r, bias, seq_lens, ih},
513
514
515
516
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

517
        last_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
518
519
520

        if(ret[0] == prog.end())
        {
521
            prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
522
523
524
525
526
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
527
528
            prog.replace_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
529
530
531
        }
    }

532
533
534
535
    // in case of all sequences are of the same lengths and shorter than the
    // max sequence length, need to pad 0's at the end for output hidden states
    ins = pad_hidden_states(prog, args[0], seq_lens, ins);
    replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
536
537
}

538
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
539
std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
540
                                                   module& prog,
Shucai Xiao's avatar
Shucai Xiao committed
541
542
543
544
545
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
546
{
547
548
549
550
551
552
553
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
554

Shucai Xiao's avatar
Shucai Xiao committed
555
556
    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
Shucai Xiao's avatar
Shucai Xiao committed
557
    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
558
    migraphx::shape r_shape   = r->get_shape();
Paul Fultz II's avatar
Paul Fultz II committed
559
    long hs                   = r_shape.lens()[2];
560

561
562
563
    migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
    std::vector<float> data(ss.elements(), 1.0f);
    auto l1 = prog.add_literal(migraphx::literal{ss, data});
564

565
    // w matrix squeeze to 2-dim and do a transpose
566
    std::vector<int64_t> perm{1, 0};
567
    auto sw = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
568
    auto tw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
569

570
    // r slide to two part, zr and h
571
572
573
    auto sr  = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
    auto rzr = prog.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2 * hs}}}), sr);
574
    auto trzr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rzr);
575

576
577
    auto rh = prog.insert_instruction(
        ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), sr);
578
    auto trh = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), rh);
579
580

    // initial states
581
    auto sih  = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
Shucai Xiao's avatar
Shucai Xiao committed
582
    size_t bs = ih->get_shape().lens()[1];
583
584

    // bias
585
586
587
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
588
589
    if(bias != prog.end())
    {
590
591
592
593
594
        auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto wb    = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {3 * hs}}}), sbias);
        bwb = prog.insert_instruction(
            ins,
595
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(3 * hs)}}}),
596
597
598
599
600
601
602
603
604
605
606
607
            wb);

        auto rb_zr = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {3 * hs}}, {"ends", {5 * hs}}}),
            sbias);
        auto rb_h = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {5 * hs}}, {"ends", {6 * hs}}}),
            sbias);
        brb_zr = prog.insert_instruction(
            ins,
608
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(2 * hs)}}}),
609
610
611
            rb_zr);
        brb_h = prog.insert_instruction(
            ins,
612
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, static_cast<size_t>(hs)}}}),
613
            rb_h);
614
615
    }

Paul Fultz II's avatar
Paul Fultz II committed
616
    long seq_len = get_seq_len(prog, seq, seq_lens);
617
618
619
    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
620
621
622
623
624
625
626
627
628
        auto xt        = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);

        auto xt_w    = prog.insert_instruction(ins, make_op("dot"), xt, tw);
        auto ih1_rzr = prog.insert_instruction(ins, make_op("dot"), sih, trzr);
Shucai Xiao's avatar
Shucai Xiao committed
629
        if(bias != prog.end())
630
        {
631
632
            xt_w    = prog.insert_instruction(ins, make_op("add"), xt_w, bwb);
            ih1_rzr = prog.insert_instruction(ins, make_op("add"), ih1_rzr, brb_zr);
633
634
        }

635
636
637
638
639
640
        auto xw_z = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_w);
        auto xw_r = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_w);
        auto xw_h = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), xt_w);
641

642
643
644
645
        auto hr_z = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), ih1_rzr);
        auto hr_r = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), ih1_rzr);
646

647
        auto xw_hr_z = prog.insert_instruction(ins, make_op("add"), xw_z, hr_z);
Shucai Xiao's avatar
Shucai Xiao committed
648
        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);
649

650
        auto xw_hr_r = prog.insert_instruction(ins, make_op("add"), xw_r, hr_r);
Shucai Xiao's avatar
Shucai Xiao committed
651
        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);
652
653
654
655
656

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
657
658
            auto rt_ht1 = prog.insert_instruction(ins, make_op("mul"), rt, sih);
            hr_h        = prog.insert_instruction(ins, make_op("dot"), rt_ht1, trh);
Shucai Xiao's avatar
Shucai Xiao committed
659
            if(bias != prog.end())
660
            {
661
                hr_h = prog.insert_instruction(ins, make_op("add"), hr_h, brb_h);
662
            }
663
664
665
        }
        else
        {
666
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
667
            auto ht1_rh = prog.insert_instruction(ins, make_op("dot"), sih, trh);
Shucai Xiao's avatar
Shucai Xiao committed
668
            if(bias != prog.end())
669
            {
670
                ht1_rh = prog.insert_instruction(ins, make_op("add"), ht1_rh, brb_h);
671
            }
672
            hr_h = prog.insert_instruction(ins, make_op("mul"), rt, ht1_rh);
673
674
        }

675
        auto xw_hr_h = prog.insert_instruction(ins, make_op("add"), xw_h, hr_h);
Shucai Xiao's avatar
Shucai Xiao committed
676
        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);
677

678
        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
679
680
681
682
683
        auto one_minus_zt    = prog.insert_instruction(ins, make_op("sub"), l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, make_op("mul"), one_minus_zt, ht);
        auto zt_ht1          = prog.insert_instruction(ins, make_op("mul"), zt, sih);
        sih         = prog.insert_instruction(ins, make_op("add"), one_minus_zt_ht, zt_ht1);
        last_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), sih);
684
685
686
687
688
689
690
691

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
692
693
                        : prog.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), hidden_states, last_output);
694
695
696
697
698
699
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
700
701
                        : prog.insert_instruction(
                              ins, make_op("concat", {{"axis", 0}}), last_output, hidden_states);
702
703
704
705
706
707
708
709
710
711
712
713
714
715
            }
        }
    }

    return {hidden_states, last_output};
}

// Return the activation functions to use when rewriting the gru operator
// referenced by ins. The result has at least 4 entries for a bidirectional
// gru and at least 2 entries otherwise; when the operator carries fewer, the
// missing ones are filled in from the ones that are present (or from the
// sigmoid/tanh defaults when none are given).
std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    auto gru_op            = any_cast<op::gru>(ins->get_operator());
    const auto& actv_funcs = gru_op.actv_funcs;
    // before rewriting the gru operator, need to ensure we have enough
    // actv funcs, even though a user does not specify any actv func.
    // If fewer are given, use the algorithm in parse_gru to fill them in
    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        switch(actv_funcs.size())
        {
        case 0: return {make_op("sigmoid"), make_op("tanh"), make_op("sigmoid"), make_op("tanh")};

        case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};

        case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(0), actv_funcs.at(1)};

        case 3: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(2), actv_funcs.at(0)};

        // 4 or more: use what the operator carries unchanged
        default: return actv_funcs;
        }
    }
    else
    {
        switch(actv_funcs.size())
        {
        case 0: return {make_op("sigmoid"), make_op("tanh")};

        case 1: return {actv_funcs.at(0), actv_funcs.at(0)};

        // 2 or more: use what the operator carries unchanged
        default: return actv_funcs;
        }
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
749
// for lstm operators
750
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
751
void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
Shucai Xiao's avatar
Shucai Xiao committed
752
753
754
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();
Shucai Xiao's avatar
Shucai Xiao committed
755

Shucai Xiao's avatar
Shucai Xiao committed
756
    shape seq_shape         = args[0]->get_shape();
757
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
Shucai Xiao's avatar
Shucai Xiao committed
758
759
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
Shucai Xiao's avatar
Shucai Xiao committed
760
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
761
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
Shucai Xiao's avatar
Shucai Xiao committed
762
763
764

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
765
766
    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
Shucai Xiao's avatar
Shucai Xiao committed
767
    op::rnn_direction dirct = lstm_op.direction;
Shucai Xiao's avatar
Shucai Xiao committed
768

Shucai Xiao's avatar
Shucai Xiao committed
769
770
771
772
773
774
775
776
777
778
    // process sequence length
    instruction_ref seq_lens = prog.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

    instruction_ref last_hs_output{};
Shucai Xiao's avatar
Shucai Xiao committed
779
    instruction_ref last_cell_output{};
Shucai Xiao's avatar
Shucai Xiao committed
780
781
    instruction_ref hidden_state{};
    instruction_ref cell_outputs{};
Shucai Xiao's avatar
Shucai Xiao committed
782
    if(dirct == op::rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
783
784
785
    {
        // input weight matrix
        // input weight matrix
786
787
788
789
        auto w_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[1]);
        auto w_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[1]);
Shucai Xiao's avatar
Shucai Xiao committed
790
791

        // hidden state weight matrix
792
793
794
795
        auto r_forward = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[2]);
        auto r_reverse = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[2]);
Shucai Xiao's avatar
Shucai Xiao committed
796
797
798
799

        // process bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
800
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
801
        {
802
803
804
805
            bias_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[3]);
            bias_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[3]);
Shucai Xiao's avatar
Shucai Xiao committed
806
807
808
809
810
811
812
        }

        // process intial hidden state, it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
813
814
815
816
            ih_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[5]);
            ih_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[5]);
Shucai Xiao's avatar
Shucai Xiao committed
817
818
819
820
821
822
823
824
825
826
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
827
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
828
        {
829
830
831
832
            ic_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[6]);
            ic_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[6]);
Shucai Xiao's avatar
Shucai Xiao committed
833
834
835
836
837
838
839
840
        }
        else
        {
            ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
841
842
        instruction_ref pph_forward = prog.end();
        instruction_ref pph_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
843
        if(args.size() == 8 && args[7]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
844
        {
845
846
847
848
            pph_forward = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), args[7]);
            pph_reverse = prog.insert_instruction(
                ins, make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), args[7]);
Shucai Xiao's avatar
Shucai Xiao committed
849
        }
Shucai Xiao's avatar
Shucai Xiao committed
850

Shucai Xiao's avatar
Shucai Xiao committed
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
        auto ret_forward = lstm_cell(true,
                                     prog,
                                     ins,
                                     {args[0],
                                      w_forward,
                                      r_forward,
                                      bias_forward,
                                      seq_lens,
                                      ih_forward,
                                      ic_forward,
                                      pph_forward},
                                     actv_funcs.at(0),
                                     actv_funcs.at(1),
                                     actv_funcs.at(2));

        if(variable_seq_len)
        {
868
869
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
Shucai Xiao's avatar
Shucai Xiao committed
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
        }
        auto ret_reverse = lstm_cell(false,
                                     prog,
                                     ins,
                                     {args[0],
                                      w_reverse,
                                      r_reverse,
                                      bias_reverse,
                                      seq_lens,
                                      ih_reverse,
                                      ic_reverse,
                                      pph_reverse},
                                     actv_funcs.at(3),
                                     actv_funcs.at(4),
                                     actv_funcs.at(5));

886
887
888
889
890
891
892
893
        auto concat_hs_output = prog.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[1], ret_reverse[1]);
        auto concat_cell_output = prog.insert_instruction(
            ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
        last_hs_output =
            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_hs_output);
        last_cell_output =
            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), concat_cell_output);
Shucai Xiao's avatar
Shucai Xiao committed
894
895

        // the following logic is to ensure the last instruction is a concat
Shucai Xiao's avatar
Shucai Xiao committed
896
        if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
897
        {
Shucai Xiao's avatar
Shucai Xiao committed
898
            cell_outputs = concat_cell_output;
Shucai Xiao's avatar
Shucai Xiao committed
899
900
901
        }
        else
        {
902
903
904
905
            ret_forward[1] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[0], ret_forward[1]);
            ret_reverse[1] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[1], ret_reverse[0]);
Shucai Xiao's avatar
Shucai Xiao committed
906

907
908
909
910
911
912
            ret_forward[3] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_forward[2], ret_forward[3]);
            ret_reverse[3] = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), ret_reverse[3], ret_reverse[2]);
            cell_outputs = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 1}}), ret_forward[3], ret_reverse[3]);
Shucai Xiao's avatar
Shucai Xiao committed
913
        }
Shucai Xiao's avatar
Shucai Xiao committed
914

915
916
        hidden_state = prog.replace_instruction(
            ins, make_op("concat", {{"axis", 1}}), {ret_forward[1], ret_reverse[1]});
Shucai Xiao's avatar
Shucai Xiao committed
917
918
919
    }
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
920
        bool is_forward = (dirct == op::rnn_direction::forward);
Shucai Xiao's avatar
Shucai Xiao committed
921
922
923
924
925
926
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
927
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
928
929
930
931
932
933
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
Shucai Xiao's avatar
Shucai Xiao committed
934
        if(args.size() >= 6 && args[5]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
935
936
937
938
939
940
941
942
943
944
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
Shucai Xiao's avatar
Shucai Xiao committed
945
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
946
947
948
949
950
951
952
953
954
        {
            ic = args[6];
        }
        else
        {
            ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
955
        instruction_ref pph = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
956
957
958
959
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }
Shucai Xiao's avatar
Shucai Xiao committed
960

Shucai Xiao's avatar
Shucai Xiao committed
961
962
        if(!is_forward and variable_seq_len)
        {
963
964
            args[0] = prog.insert_instruction(
                ins, make_op("rnn_var_sl_shift_sequence"), args[0], seq_lens);
Shucai Xiao's avatar
Shucai Xiao committed
965
        }
Shucai Xiao's avatar
Shucai Xiao committed
966
        auto ret = lstm_cell(is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
967
968
                             prog,
                             ins,
Shucai Xiao's avatar
Shucai Xiao committed
969
                             {args[0], w, r, bias, seq_lens, ih, ic, pph},
Shucai Xiao's avatar
Shucai Xiao committed
970
971
972
973
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

974
975
976
        last_hs_output = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[1]);
        last_cell_output =
            prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ret[3]);
Shucai Xiao's avatar
Shucai Xiao committed
977

Shucai Xiao's avatar
Shucai Xiao committed
978
        if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
979
        {
Shucai Xiao's avatar
Shucai Xiao committed
980
            cell_outputs = ret[3];
981
            hidden_state = prog.replace_instruction(ins, make_op("concat", {{"axis", 0}}), ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
982
983
984
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
985
986
            auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
            auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
987
988
            cell_outputs          = prog.insert_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
989

Shucai Xiao's avatar
Shucai Xiao committed
990
991
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
992
993
            hidden_state     = prog.replace_instruction(
                ins, make_op("concat", {{"axis", 0}}), concat_arg0, concat_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
994
995
996
        }
    }

997
998
999
1000
1001
    // in case of all sequences are of the same lengths and shorter than the
    // max sequence length, need to pad 0's at the end for output hidden states
    hidden_state = pad_hidden_states(prog, args[0], seq_lens, hidden_state);

    // replace last hidden states with corresponding instructions
Shucai Xiao's avatar
Shucai Xiao committed
1002
    ins = replace_last_hs_output(prog, hidden_state, seq_lens, last_hs_output, dirct);
1003
1004

    // replace last cell outputs with corresponding instructions
Shucai Xiao's avatar
Shucai Xiao committed
1005
    replace_last_cell_output(prog, ins, seq_lens, cell_outputs, last_cell_output, dirct);
Shucai Xiao's avatar
Shucai Xiao committed
1006
1007
}

1008
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
Shucai Xiao's avatar
Shucai Xiao committed
1009
std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
1010
                                                    module& prog,
Shucai Xiao's avatar
Shucai Xiao committed
1011
1012
1013
1014
1015
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
Shucai Xiao's avatar
Shucai Xiao committed
1016
{
Shucai Xiao's avatar
Shucai Xiao committed
1017
    // must have 7 args in the input vector
Shucai Xiao's avatar
Shucai Xiao committed
1018
1019
1020
1021
1022
1023
1024
1025
1026
    assert(inputs.size() == 8);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
    auto ic       = inputs.at(6);
    auto pph      = inputs.at(7);
Shucai Xiao's avatar
Shucai Xiao committed
1027

Shucai Xiao's avatar
Shucai Xiao committed
1028
    instruction_ref hidden_states = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
1029
1030
1031
    instruction_ref cell_outputs  = prog.end();

    instruction_ref last_hs_output{};
Shucai Xiao's avatar
Shucai Xiao committed
1032
1033
    instruction_ref last_cell_output{};

1034
    migraphx::shape r_shape = r->get_shape();
Paul Fultz II's avatar
Paul Fultz II committed
1035
    long hs                 = r_shape.lens()[2];
1036
    auto bs                 = ih->get_shape().lens()[1];
Shucai Xiao's avatar
Shucai Xiao committed
1037
1038

    std::vector<int64_t> perm{1, 0};
1039
    // w matrix, squeeze and transpose
1040
    auto sw  = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), w);
1041
    auto tsw = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sw);
Shucai Xiao's avatar
Shucai Xiao committed
1042

1043
    // r matrix, squeeze and transpose
1044
    auto sr  = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), r);
1045
    auto tsr = prog.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), sr);
Shucai Xiao's avatar
Shucai Xiao committed
1046

Shucai Xiao's avatar
Shucai Xiao committed
1047
    // initial hidden state
1048
    auto sih = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ih);
Shucai Xiao's avatar
Shucai Xiao committed
1049
1050

    // initial cell state
1051
    auto sic     = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), ic);
1052
    auto ic_lens = sic->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
1053

1054
    // bias
1055
    instruction_ref wrb{};
Shucai Xiao's avatar
Shucai Xiao committed
1056
    if(bias != prog.end())
1057
    {
1058

1059
1060
1061
1062
1063
1064
1065
1066
        auto sbias = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), bias);
        auto ub_wb = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4 * hs}}}), sbias);
        auto ub_rb = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {4 * hs}}, {"ends", {8 * hs}}}),
            sbias);
        auto ub_wrb = prog.insert_instruction(ins, make_op("add"), ub_wb, ub_rb);
1067

1068
        wrb = prog.insert_instruction(
1069
            ins,
1070
            make_op("broadcast", {{"axis", 1}, {"out_lens", {bs, 4 * static_cast<size_t>(hs)}}}),
1071
            ub_wrb);
1072
1073
    }

Shucai Xiao's avatar
Shucai Xiao committed
1074
    // peep hole
Shucai Xiao's avatar
Shucai Xiao committed
1075
1076
1077
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
Shucai Xiao's avatar
Shucai Xiao committed
1078
    if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
1079
    {
1080
1081
1082
1083
        auto spph = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), pph);
        auto pphi = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {hs}}}), spph);
        pphi_brcst = prog.insert_instruction(
1084
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphi);
1085
1086
1087
1088

        auto ppho = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {hs}}, {"ends", {2 * hs}}}), spph);
        ppho_brcst = prog.insert_instruction(
1089
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), ppho);
1090
1091
1092
1093

        auto pphf = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {0}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}), spph);
        pphf_brcst = prog.insert_instruction(
1094
            ins, make_op("broadcast", {{"axis", 1}, {"out_lens", ic_lens}}), pphf);
Shucai Xiao's avatar
Shucai Xiao committed
1095
    }
Shucai Xiao's avatar
Shucai Xiao committed
1096

Paul Fultz II's avatar
Paul Fultz II committed
1097
    long seq_len = get_seq_len(prog, seq, seq_lens);
Shucai Xiao's avatar
Shucai Xiao committed
1098
    for(long i = 0; i < seq_len; ++i)
Shucai Xiao's avatar
Shucai Xiao committed
1099
1100
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
        auto xt        = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {0}}, {"starts", {seq_index}}, {"ends", {seq_index + 1}}}),
            seq);
        auto cont_xt = prog.insert_instruction(ins, make_op("contiguous"), xt);
        xt           = prog.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), cont_xt);

        auto xt_tsw  = prog.insert_instruction(ins, make_op("dot"), xt, tsw);
        auto sih_tsr = prog.insert_instruction(ins, make_op("dot"), sih, tsr);
        auto xt_sih  = prog.insert_instruction(ins, make_op("add"), xt_tsw, sih_tsr);
Shucai Xiao's avatar
Shucai Xiao committed
1111
        if(bias != prog.end())
1112
        {
1113
            xt_sih = prog.insert_instruction(ins, make_op("add"), xt_sih, wrb);
1114
        }
Shucai Xiao's avatar
Shucai Xiao committed
1115

1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
        auto it_before_actv = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {hs}}}), xt_sih);
        auto ot_before_actv = prog.insert_instruction(
            ins, make_op("slice", {{"axes", {1}}, {"starts", {hs}}, {"ends", {2 * hs}}}), xt_sih);
        auto ft_before_actv = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {2 * hs}}, {"ends", {3 * hs}}}),
            xt_sih);
        auto ct_before_actv = prog.insert_instruction(
            ins,
            make_op("slice", {{"axes", {1}}, {"starts", {3 * hs}}, {"ends", {4 * hs}}}),
            xt_sih);
1128

Shucai Xiao's avatar
Shucai Xiao committed
1129
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
1130
        {
1131
1132
            auto pphi_ct   = prog.insert_instruction(ins, make_op("mul"), pphi_brcst, sic);
            it_before_actv = prog.insert_instruction(ins, make_op("add"), it_before_actv, pphi_ct);
Shucai Xiao's avatar
Shucai Xiao committed
1133

1134
1135
            auto pphf_ct   = prog.insert_instruction(ins, make_op("mul"), pphf_brcst, sic);
            ft_before_actv = prog.insert_instruction(ins, make_op("add"), ft_before_actv, pphf_ct);
Shucai Xiao's avatar
Shucai Xiao committed
1136
        }
Shucai Xiao's avatar
Shucai Xiao committed
1137
        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
1138
        auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
1139
        auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
1140
1141

        // equation Ct = ft (.) Ct-1 + it (.) ct
1142
1143
1144
        auto ft_cell = prog.insert_instruction(ins, make_op("mul"), ft, sic);
        auto it_ct   = prog.insert_instruction(ins, make_op("mul"), it, ct);
        auto cellt   = prog.insert_instruction(ins, make_op("add"), ft_cell, it_ct);
1145

Shucai Xiao's avatar
Shucai Xiao committed
1146
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
1147
        {
1148
1149
1150
            auto ppho_cellt = prog.insert_instruction(ins, make_op("mul"), ppho_brcst, cellt);
            ot_before_actv =
                prog.insert_instruction(ins, make_op("add"), ot_before_actv, ppho_cellt);
Shucai Xiao's avatar
Shucai Xiao committed
1151
        }
1152
1153
1154
1155
        auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);

        // Ht = ot (.) h(Ct)
        auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
1156
        auto ht      = prog.insert_instruction(ins, make_op("mul"), ot, h_cellt);
1157
1158
1159
1160

        sic = cellt;
        sih = ht;

1161
1162
1163
        last_hs_output = prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), ht);
        last_cell_output =
            prog.insert_instruction(ins, make_op("unsqueeze", {{"axes", {0, 1}}}), cellt);
1164

Shucai Xiao's avatar
Shucai Xiao committed
1165
        if(i < seq_len - 1)
1166
        {
Shucai Xiao's avatar
Shucai Xiao committed
1167
            if(i == 0)
1168
            {
Shucai Xiao's avatar
Shucai Xiao committed
1169
1170
                hidden_states = last_hs_output;
                cell_outputs  = last_cell_output;
1171
1172
1173
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
1174
1175
                auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
                auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
1176
1177
                hidden_states       = prog.insert_instruction(
                    ins, make_op("concat", {{"axis", 0}}), concat_hs_arg0, concat_hs_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
1178
1179
1180

                auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
                auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
1181
1182
                cell_outputs          = prog.insert_instruction(
                    ins, make_op("concat", {{"axis", 0}}), concat_cell_arg0, concat_cell_arg1);
1183
1184
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1185
    }
1186

Shucai Xiao's avatar
Shucai Xiao committed
1187
    return {hidden_states, last_hs_output, cell_outputs, last_cell_output};
Shucai Xiao's avatar
Shucai Xiao committed
1188
1189
1190
1191
1192
1193
1194
1195
1196
}

// Return the activation functions to use when rewriting the lstm operator
// referenced by ins. The result has at least 6 entries for a bidirectional
// lstm (apply_lstm uses entries 0-2 for the forward direction and 3-5 for the
// reverse direction) and at least 3 entries otherwise; when the operator
// carries fewer, the missing ones are filled in from the ones that are
// present (or from the sigmoid/tanh/tanh defaults when none are given).
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    auto lstm_op = any_cast<op::lstm>(ins->get_operator());
    // before rewriting the lstm operator, need to ensure
    // we have 6 actv funcs, even though a user does not
    // specify any actv func. If less than 6, use the
    // algorithm in parse_lstm to make 6 actv functions
    const auto& actv_funcs     = lstm_op.actv_funcs;
    std::size_t num_actv_funcs = actv_funcs.size();
    if(lstm_op.direction == op::rnn_direction::bidirectional)
    {
        switch(num_actv_funcs)
        {
        case 0:
            return {make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh"),
                    make_op("sigmoid"),
                    make_op("tanh"),
                    make_op("tanh")};

        case 1:
            return {actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0)};

        case 2:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1)};

        case 3:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2)};

        case 4:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(3),
                    actv_funcs.at(3)};

        case 5:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(4),
                    actv_funcs.at(4)};

        // 6 or more: use what the operator carries unchanged
        default: return actv_funcs;
        }
    }
    else
    {
        switch(num_actv_funcs)
        {
        case 0: return {make_op("sigmoid"), make_op("tanh"), make_op("tanh")};

        case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};

        case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};

        // 3 or more: use what the operator carries unchanged
        default: return actv_funcs;
        }
    }
}

1269
bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_lens) const
Shucai Xiao's avatar
Shucai Xiao committed
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
{
    bool is_var_lens = false;
    if(seq_lens != prog.end())
    {
        if(seq_lens->can_eval())
        {
            auto arg_lens = seq_lens->eval();
            std::vector<int64_t> vec_lens;
            arg_lens.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
            int64_t l = 0;
            if(!vec_lens.empty())
            {
                l = vec_lens[0];
            }
            if(!std::all_of(vec_lens.begin(), vec_lens.end(), [&](auto v) { return v == l; }))
            {
                is_var_lens = true;
            }
        }
        else
        {
            is_var_lens = true;
        }
    }

    return is_var_lens;
}

std::size_t
1299
rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_ref seq_lens) const
Shucai Xiao's avatar
Shucai Xiao committed
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
{
    bool is_var_lens = is_variable_seq_lens(prog, seq_lens);
    auto input_shape = input->get_shape();
    auto length      = input_shape.lens()[0];
    if(!is_var_lens and seq_lens != prog.end())
    {
        auto arg_len = seq_lens->eval();
        std::vector<std::size_t> vec_lens;
        arg_len.visit([&](auto l) { vec_lens.assign(l.begin(), l.end()); });
        length = vec_lens.empty() ? length : vec_lens[0];
    }

    return length;
}

1315
// Replace every "rnn_last_hs_output" user of the hidden-state instruction.
//
// Fixed sequence length: each such user is replaced by last_hs_output and
// ins itself is returned.
//
// Variable sequence length: a "rnn_var_sl_shift_output" instruction taking
// (ins, seq_lens) is inserted right after ins and takes the place of ins;
// each "rnn_last_hs_output" user is then replaced by a
// "rnn_var_sl_last_output" op applied to the user's first input and
// seq_lens. The inserted shift instruction is returned.
instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
                                                    instruction_ref ins,
                                                    instruction_ref seq_lens,
                                                    instruction_ref last_hs_output,
                                                    op::rnn_direction dirct) const
{
    const auto is_last_hs = [](auto i) { return i->name() == "rnn_last_hs_output"; };

    if(!is_variable_seq_lens(prog, seq_lens))
    {
        auto last_hs_users = find_all(ins->outputs(), is_last_hs);
        for(auto& user : last_hs_users)
        {
            prog.replace_instruction(user, last_hs_output);
        }

        return ins;
    }

    auto shifted = prog.insert_instruction(
        std::next(ins),
        make_op("rnn_var_sl_shift_output",
                {{"output_name", "hidden_states"}, {"direction", dirct}}),
        ins,
        seq_lens);
    prog.replace_instruction(ins, shifted);

    // the users are collected from the shift instruction, after it has
    // replaced ins
    auto last_hs_users = find_all(shifted->outputs(), is_last_hs);
    for(auto& user : last_hs_users)
    {
        auto user_inputs = user->inputs();
        prog.replace_instruction(user,
                                 make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                 user_inputs.front(),
                                 seq_lens);
    }

    return shifted;
}

1360
// Replace every "rnn_last_cell_output" user of ins.
//
// Fixed sequence length: each such user is replaced by last_cell_output.
//
// Variable sequence length: cell_outputs is first wrapped in a
// "rnn_var_sl_shift_output" instruction (inserted right after ins), and each
// user is replaced by a "rnn_var_sl_last_output" op applied to the shifted
// cell outputs and seq_lens.
//
// Nothing is inserted or replaced when ins has no such user.
void rewrite_rnn::replace_last_cell_output(module& prog,
                                           instruction_ref ins,
                                           instruction_ref seq_lens,
                                           instruction_ref cell_outputs,
                                           instruction_ref last_cell_output,
                                           op::rnn_direction dirct) const
{
    const bool var_len = is_variable_seq_lens(prog, seq_lens);
    auto last_cell_users =
        find_all(ins->outputs(), [](auto i) { return i->name() == "rnn_last_cell_output"; });

    if(last_cell_users.empty())
    {
        return;
    }

    if(!var_len)
    {
        // the loop covers the case of several rnn_last_cell_output operators
        for(auto user : last_cell_users)
        {
            prog.replace_instruction(user, last_cell_output);
        }

        return;
    }

    auto shifted_cells = prog.insert_instruction(
        std::next(ins),
        make_op("rnn_var_sl_shift_output",
                {{"output_name", "cell_outputs"}, {"direction", dirct}}),
        cell_outputs,
        seq_lens);

    for(auto user : last_cell_users)
    {
        prog.replace_instruction(user,
                                 make_op("rnn_var_sl_last_output", {{"direction", dirct}}),
                                 shifted_cells,
                                 seq_lens);
    }
}

1402
// Pad the hidden states hs along axis 0 when the sequence length returned by
// get_seq_len is shorter than the first dimension of seq. The padding is a
// zero-filled literal with hs's shape except for dimension 0, which is the
// missing number of steps; it is concatenated after hs and the concat takes
// the place of hs. Returns the concat instruction, or hs unchanged when no
// padding is needed.
instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
                                               instruction_ref seq,
                                               instruction_ref seq_lens,
                                               instruction_ref hs) const
{
    const auto max_seq_len = seq->get_shape().lens()[0];
    const auto seq_len     = get_seq_len(prog, seq, seq_lens);

    if(seq_len >= max_seq_len)
    {
        return hs;
    }

    // same dimensions as hs except for the number of steps to append
    auto hs_shape    = hs->get_shape();
    auto padded_lens = hs_shape.lens();
    padded_lens[0]   = static_cast<std::size_t>(max_seq_len - seq_len);
    shape pad_s{hs_shape.type(), padded_lens};

    std::vector<float> zeros(pad_s.elements(), 0.0f);
    auto zero_lit = prog.add_literal(pad_s, zeros.begin(), zeros.end());

    auto padded =
        prog.insert_instruction(std::next(hs), make_op("concat", {{"axis", 0}}), hs, zero_lit);
    prog.replace_instruction(hs, padded);

    return padded;
}

1429
namespace op {
Shucai Xiao's avatar
Shucai Xiao committed
1430
std::ostream& operator<<(std::ostream& os, rnn_direction v)
1431
{
Shucai Xiao's avatar
Shucai Xiao committed
1432
1433
    std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
1434
1435
    return os;
}
Shucai Xiao's avatar
Shucai Xiao committed
1436
} // namespace op
1437

Shucai Xiao's avatar
Shucai Xiao committed
1438
1439
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx