// rewrite_rnn.cpp
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
17
#include <migraphx/op/contiguous.hpp>
18
19
20
#include <migraphx/op/common.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
21
22
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
23
#include <migraphx/ranges.hpp>
24
#include <migraphx/op/common.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
25
26
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
27
28
29
30

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Shucai Xiao's avatar
Shucai Xiao committed
31
void rewrite_rnn::apply(program& prog) const
Shucai Xiao's avatar
Shucai Xiao committed
32
33
34
{
    for(auto ins : iterator_for(prog))
    {
Shucai Xiao's avatar
Shucai Xiao committed
35
        if(ins->name() == "rnn")
Shucai Xiao's avatar
Shucai Xiao committed
36
        {
Shucai Xiao's avatar
Shucai Xiao committed
37
            apply_vanilla_rnn(prog, ins);
38
        }
39
        else if(ins->name() == "gru")
40
41
        {
            apply_gru(prog, ins);
Shucai Xiao's avatar
Shucai Xiao committed
42
        }
43
44
45
46
        else if(ins->name() == "lstm")
        {
            apply_lstm(prog, ins);
        }
Shucai Xiao's avatar
Shucai Xiao committed
47
    }
48
49
}

Shucai Xiao's avatar
Shucai Xiao committed
50
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
51
52
53
54
{
    assert(ins->name() == "rnn");
    // could be 3 to 6 inputs, but the parse_rnn function will
    // append undefined operators to make 6 arguments when parsing
Shucai Xiao's avatar
Shucai Xiao committed
55
    // an onnx file. Another case is user can have num of arguments
56
57
58
59
60
61
62
63
64
65
    // when writing their program.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0);

Shucai Xiao's avatar
Shucai Xiao committed
66
67
    auto actv_funcs         = vanilla_rnn_actv_funcs(ins);
    auto rnn_op             = any_cast<op::rnn>(ins->get_operator());
68
69
70
71
72
73
74
75
76
77
78
    op::rnn_direction dirct = rnn_op.direction;

    // process sequence length
    instruction_ref seq_lens = prog.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

79
    instruction_ref last_output{};
80
    if(dirct == op::rnn_direction::bidirectional)
81
82
83
84
85
86
87
88
89
90
91
92
    {
        // input weight matrix
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // process bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
93
        if(args.size() >= 4 && args[3]->name() != "undefined")
94
95
96
97
98
99
100
101
102
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // process intial hidden state, it could be the 6th argument
        // or the 5th one (if the sequence len argument is ignored)
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
103
        if(args.size() == 6 && args[5]->name() != "undefined")
104
105
106
107
108
109
110
111
112
113
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
        auto ret_forward =
            vanilla_rnn_cell(true,
                             prog,
                             ins,
                             {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                             actv_funcs.at(0));

        if(variable_seq_len)
        {
            args[0] =
                prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
        }

        auto ret_reverse =
            vanilla_rnn_cell(false,
                             prog,
                             ins,
                             {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                             actv_funcs.at(1));
133
134
135
136
137
138
139
140
141
142

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic is to ensure the last instruction rewritten from
        // rnn operator is a concat instruction
        // sequence len is 1
        if(ret_forward[0] == prog.end())
        {
143
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
144
145
146
147
148
149
150
        }
        else
        {
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
151
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
152
153
154
155
        }
    }
    else
    {
156
        bool is_forward = (dirct == op::rnn_direction::forward);
157
158
159
160
161
162
163
164
        // input weight matrix
        auto w = args[1];

        // hidden state weight matrix
        auto r = args[2];

        // process bias and initial hidden state
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
165
        if(args.size() >= 4 && args[3]->name() != "undefined")
166
167
168
169
170
171
        {
            bias = args[3];
        }

        // process intial hidden state
        instruction_ref ih;
Shucai Xiao's avatar
Shucai Xiao committed
172
        if(args.size() == 6 && args[5]->name() != "undefined")
173
174
175
176
177
178
179
180
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

181
182
183
184
185
186
187
188
        if(!is_forward and variable_seq_len)
        {
            args[0] =
                prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
        }

        auto ret = vanilla_rnn_cell(
            is_forward, prog, ins, {args[0], w, r, bias, seq_lens, ih}, actv_funcs.at(0));
189
190
191
192
193
194
195
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        // following logic is to ensure the last instruction is a
        // concat instruction
        // sequence len is 1
        if(ret[0] == prog.end())
        {
196
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
197
198
199
200
201
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
202
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
203
204
205
        }
    }

206
207
208
209
    // in case of all sequences are of the same lengths and shorter than the
    // max sequence length, need to pad 0's at the end for output hidden states
    ins = pad_hidden_states(prog, args[0], seq_lens, ins);
    replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
Shucai Xiao's avatar
Shucai Xiao committed
210
211
}

Shucai Xiao's avatar
Shucai Xiao committed
212
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
213
214
                                                           program& prog,
                                                           instruction_ref ins,
215
                                                           std::vector<instruction_ref> inputs,
Shucai Xiao's avatar
Shucai Xiao committed
216
                                                           operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
217
{
218
219
220
221
222
223
224
225
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);

Shucai Xiao's avatar
Shucai Xiao committed
226
227
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
228
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
Shucai Xiao's avatar
Shucai Xiao committed
229
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
230
231

    // squeeze and transpose r
Shucai Xiao's avatar
Shucai Xiao committed
232
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
Shucai Xiao's avatar
Shucai Xiao committed
233
234
235
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
Shucai Xiao's avatar
Shucai Xiao committed
236
    auto sih      = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
237
    auto sih_lens = sih->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
238
239

    // bias
240
    instruction_ref bb{};
Shucai Xiao's avatar
Shucai Xiao committed
241
    if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
242
    {
243
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
Shucai Xiao's avatar
Shucai Xiao committed
244
245
246
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
247
248
        auto wrb   = prog.insert_instruction(ins, op::add{}, wb, rb);
        bb         = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
Shucai Xiao's avatar
Shucai Xiao committed
249
250
    }

Shucai Xiao's avatar
Shucai Xiao committed
251
252
    instruction_ref hidden_out = prog.end();
    instruction_ref last_out{};
253
254
255
    last_out     = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
    long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
    for(long i = 0; i < seq_len; i++)
Shucai Xiao's avatar
Shucai Xiao committed
256
    {
Shucai Xiao's avatar
Shucai Xiao committed
257
        long seq_index = is_forward ? i : (seq_len - 1 - i);
258
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
Shucai Xiao's avatar
Shucai Xiao committed
259
260
261
262
        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
        xt           = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
        auto xt_wi   = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri   = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
Shucai Xiao's avatar
Shucai Xiao committed
263
        if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
264
        {
265
            xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
Shucai Xiao's avatar
Shucai Xiao committed
266
        }
267
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
Shucai Xiao's avatar
Shucai Xiao committed
268
269

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
270
271
        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;
Shucai Xiao's avatar
Shucai Xiao committed
272

Shucai Xiao's avatar
Shucai Xiao committed
273
274
275
        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
Shucai Xiao's avatar
Shucai Xiao committed
276

Shucai Xiao's avatar
Shucai Xiao committed
277
278
279
        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
Shucai Xiao's avatar
Shucai Xiao committed
280
        if(i < seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
281
        {
Shucai Xiao's avatar
Shucai Xiao committed
282
283
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
284
285
286
287
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
288
289
290
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
291
292
293
294
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
295
            }
Shucai Xiao's avatar
Shucai Xiao committed
296
297
298
        }
    }

299
    return {hidden_out, last_out};
Shucai Xiao's avatar
Shucai Xiao committed
300
301
}

Shucai Xiao's avatar
Shucai Xiao committed
302
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
303
304
{
    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
Shucai Xiao's avatar
Shucai Xiao committed
305
306
307
308
    // could be 3 to 6 inputs, but the parse_gru function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have any num of arguments
    // when writing their program.
309
    if(rnn_op.direction == op::rnn_direction::bidirectional)
310
    {
Shucai Xiao's avatar
Shucai Xiao committed
311
        if(rnn_op.actv_funcs.empty())
312
313
314
315
        {
            // default is tanh
            return {op::tanh{}, op::tanh{}};
        }
Shucai Xiao's avatar
Shucai Xiao committed
316
        else if(rnn_op.actv_funcs.size() == 1)
317
318
319
320
321
322
323
324
325
326
        {
            return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
327
        if(rnn_op.actv_funcs.empty())
328
329
330
331
332
333
334
335
336
337
338
        {
            // default is tanh
            return {op::tanh{}};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
}

339
340
341
342
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);
Shucai Xiao's avatar
Shucai Xiao committed
343
344
345
346
    // could be 3 to 6 inputs, but the parse_gru function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have num of arguments
    // when writing their program.
347
348
349
350
351
352
353
354
355
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0.0);

Shucai Xiao's avatar
Shucai Xiao committed
356
    auto gru_op             = any_cast<op::gru>(ins->get_operator());
357
358
359
360
361
362
363
364
365
366
367
    op::rnn_direction dirct = gru_op.direction;

    // process sequence length
    instruction_ref seq_lens = prog.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

368
    instruction_ref last_output{};
369
    if(dirct == op::rnn_direction::bidirectional)
370
371
372
373
374
375
376
377
378
379
380
381
    {
        // w weight matrix
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // r weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
382
        if(args.size() >= 4 && args[3]->name() != "undefined")
383
384
385
386
387
388
389
390
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // intial hidden state
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
391
        if(args.size() == 6 && args[5]->name() != "undefined")
392
393
394
395
396
397
398
399
400
401
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

402
403
404
405
406
407
408
409
410
411
412
413
414
415
        auto ret_forward =
            gru_cell(true,
                     prog,
                     ins,
                     {args[0], w_forward, r_forward, bias_forward, seq_lens, ih_forward},
                     gru_op.linear_before_reset,
                     actv_funcs.at(0),
                     actv_funcs.at(1));

        if(variable_seq_len)
        {
            args[0] =
                prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
        }
Shucai Xiao's avatar
Shucai Xiao committed
416

417
418
419
420
421
422
423
424
        auto ret_reverse =
            gru_cell(false,
                     prog,
                     ins,
                     {args[0], w_reverse, r_reverse, bias_reverse, seq_lens, ih_reverse},
                     gru_op.linear_before_reset,
                     actv_funcs.at(2),
                     actv_funcs.at(3));
425
426
427
428
429
430
431
432
433

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic is to ensure the last instruction rewritten
        // from gru operator is a concat
        if(ret_forward[0] == prog.end())
        {
Shucai Xiao's avatar
Shucai Xiao committed
434
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
435
436
437
438
439
440
441
        }
        else
        {
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
Shucai Xiao's avatar
Shucai Xiao committed
442
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
443
444
445
446
        }
    }
    else
    {
447
        bool is_forward = (dirct == op::rnn_direction::forward);
448
449
450
451
452
453
        // weight matrix
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
454
        if(args.size() >= 4 && args[3]->name() != "undefined")
455
456
457
458
459
460
        {
            bias = args[3];
        }

        // intial hidden state
        instruction_ref ih{};
Shucai Xiao's avatar
Shucai Xiao committed
461
        if(args.size() == 6 && args[5]->name() != "undefined")
462
463
464
465
466
467
468
469
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

470
471
472
473
474
475
        if(!is_forward and variable_seq_len)
        {
            args[0] =
                prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
        }

476
477
478
        auto ret = gru_cell(is_forward,
                            prog,
                            ins,
479
                            {args[0], w, r, bias, seq_lens, ih},
480
481
482
483
484
485
486
487
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        if(ret[0] == prog.end())
        {
Shucai Xiao's avatar
Shucai Xiao committed
488
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
489
490
491
492
493
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
Shucai Xiao's avatar
Shucai Xiao committed
494
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
495
496
497
        }
    }

498
499
500
501
    // in case of all sequences are of the same lengths and shorter than the
    // max sequence length, need to pad 0's at the end for output hidden states
    ins = pad_hidden_states(prog, args[0], seq_lens, ins);
    replace_last_hs_output(prog, ins, seq_lens, last_output, dirct);
502
503
504
}

std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
505
506
507
508
509
510
                                                   program& prog,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
511
{
512
513
514
515
516
517
518
    assert(inputs.size() == 6);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
519

Shucai Xiao's avatar
Shucai Xiao committed
520
521
    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
Shucai Xiao's avatar
Shucai Xiao committed
522
    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
523
524
    migraphx::shape r_shape   = r->get_shape();
    long hs                   = static_cast<long>(r_shape.lens()[2]);
525

526
527
528
    migraphx::shape ss(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
    std::vector<float> data(ss.elements(), 1.0f);
    auto l1 = prog.add_literal(migraphx::literal{ss, data});
529

530
    // w matrix squeeze to 2-dim and do a transpose
531
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
532
    auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
533
    auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
534

535
    // r slide to two part, zr and h
Shucai Xiao's avatar
Shucai Xiao committed
536
537
538
    auto sr   = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rzr  = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
    auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
539

Shucai Xiao's avatar
Shucai Xiao committed
540
    auto rh  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
541
    auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
542
543

    // initial states
Shucai Xiao's avatar
Shucai Xiao committed
544
545
    auto sih  = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
    size_t bs = ih->get_shape().lens()[1];
546
547

    // bias
548
549
550
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
551
552
553
    if(bias != prog.end())
    {
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
554
555
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
        bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
556
557

        auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
Shucai Xiao's avatar
Shucai Xiao committed
558
559
560
561
        auto rb_h  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brb_zr     = prog.insert_instruction(
            ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
        brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
562
563
    }

564
    long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
565
566
567
568
    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
Shucai Xiao's avatar
Shucai Xiao committed
569
570
        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
        xt           = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
571

572
573
        auto xt_w    = prog.insert_instruction(ins, op::dot{}, xt, tw);
        auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
Shucai Xiao's avatar
Shucai Xiao committed
574
        if(bias != prog.end())
575
        {
576
577
            xt_w    = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
            ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
578
579
580
581
582
583
584
585
586
587
        }

        auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
        auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
        auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);

        auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
        auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);

        auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
Shucai Xiao's avatar
Shucai Xiao committed
588
        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);
589
590

        auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
Shucai Xiao's avatar
Shucai Xiao committed
591
        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);
592
593
594
595
596
597

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
598
            hr_h        = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
Shucai Xiao's avatar
Shucai Xiao committed
599
            if(bias != prog.end())
600
            {
601
                hr_h = prog.insert_instruction(ins, op::add{}, hr_h, brb_h);
602
            }
603
604
605
        }
        else
        {
606
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
607
            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
Shucai Xiao's avatar
Shucai Xiao committed
608
            if(bias != prog.end())
609
            {
610
                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brb_h);
611
            }
Shucai Xiao's avatar
Shucai Xiao committed
612
            hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
613
614
        }

615
        auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
Shucai Xiao's avatar
Shucai Xiao committed
616
        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);
617

618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
        last_output          = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}

// Return the activation functions used when lowering a "gru" instruction.
//
// apply_gru consumes them as (forward f, forward g, reverse f, reverse g) for
// a bidirectional GRU and as (f, g) otherwise. Missing entries are filled in:
//   - none given:  sigmoid for f and tanh for g, in every direction
//   - one given:   that function everywhere
//   - two given (bidirectional):   the same pair for both directions
//   - three given (bidirectional): {f0, f1, f2, f0}
// Otherwise the operator's list is returned unchanged.
std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    auto gru_op = any_cast<op::gru>(ins->get_operator());
    // before rewriting the gru operator, ensure we have 4 (bidirectional)
    // or 2 activation functions, even if the user does not specify any.
    // If fewer are given, pad them following the algorithm in parse_gru.
    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        if(gru_op.actv_funcs.empty())
        {
            return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
        }
        else if(gru_op.actv_funcs.size() == 1)
        {
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(0)};
        }
        else if(gru_op.actv_funcs.size() == 2)
        {
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1)};
        }
        else if(gru_op.actv_funcs.size() == 3)
        {
            // NOTE(review): with three functions the reverse direction uses
            // f0 (the forward gate function) as its hidden function g; confirm
            // this matches the padding parse_gru applies for the same case.
            return {gru_op.actv_funcs.at(0),
                    gru_op.actv_funcs.at(1),
                    gru_op.actv_funcs.at(2),
                    gru_op.actv_funcs.at(0)};
        }
        else
        {
            return gru_op.actv_funcs;
        }
    }
    else
    {
        if(gru_op.actv_funcs.empty())
        {
            return {op::sigmoid{}, op::tanh{}};
        }
        else if(gru_op.actv_funcs.size() == 1)
        {
            return {gru_op.actv_funcs.at(0), gru_op.actv_funcs.at(0)};
        }
        else
        {
            return gru_op.actv_funcs;
        }
    }
}

// for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();
Shucai Xiao's avatar
Shucai Xiao committed
692

Shucai Xiao's avatar
Shucai Xiao committed
693
    shape seq_shape         = args[0]->get_shape();
694
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
Shucai Xiao's avatar
Shucai Xiao committed
695
696
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
Shucai Xiao's avatar
Shucai Xiao committed
697
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
698
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
Shucai Xiao's avatar
Shucai Xiao committed
699
700
701

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
702
703
    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
Shucai Xiao's avatar
Shucai Xiao committed
704
    op::rnn_direction dirct = lstm_op.direction;
Shucai Xiao's avatar
Shucai Xiao committed
705

Shucai Xiao's avatar
Shucai Xiao committed
706
707
708
709
710
711
712
713
714
715
    // process sequence length
    instruction_ref seq_lens = prog.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

    instruction_ref last_hs_output{};
Shucai Xiao's avatar
Shucai Xiao committed
716
    instruction_ref last_cell_output{};
Shucai Xiao's avatar
Shucai Xiao committed
717
718
    instruction_ref hidden_state{};
    instruction_ref cell_outputs{};
Shucai Xiao's avatar
Shucai Xiao committed
719
    if(dirct == op::rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
720
721
722
723
724
725
726
727
728
729
730
731
732
    {
        // input weight matrix
        // input weight matrix
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // process bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
733
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // process intial hidden state, it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
756
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
757
758
759
760
761
762
763
764
765
766
767
        {
            ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]);
            ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]);
        }
        else
        {
            ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
768
769
        instruction_ref pph_forward = prog.end();
        instruction_ref pph_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
770
        if(args.size() == 8 && args[7]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
771
772
773
774
        {
            pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
            pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
        }
Shucai Xiao's avatar
Shucai Xiao committed
775

Shucai Xiao's avatar
Shucai Xiao committed
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
        auto ret_forward = lstm_cell(true,
                                     prog,
                                     ins,
                                     {args[0],
                                      w_forward,
                                      r_forward,
                                      bias_forward,
                                      seq_lens,
                                      ih_forward,
                                      ic_forward,
                                      pph_forward},
                                     actv_funcs.at(0),
                                     actv_funcs.at(1),
                                     actv_funcs.at(2));

        if(variable_seq_len)
        {
            args[0] =
                prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
        }
        auto ret_reverse = lstm_cell(false,
                                     prog,
                                     ins,
                                     {args[0],
                                      w_reverse,
                                      r_reverse,
                                      bias_reverse,
                                      seq_lens,
                                      ih_reverse,
                                      ic_reverse,
                                      pph_reverse},
                                     actv_funcs.at(3),
                                     actv_funcs.at(4),
                                     actv_funcs.at(5));

        auto concat_hs_output =
Shucai Xiao's avatar
Shucai Xiao committed
812
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
813
814
815
816
        auto concat_cell_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[3], ret_reverse[3]);
        last_hs_output   = prog.insert_instruction(ins, op::squeeze{{0}}, concat_hs_output);
        last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_cell_output);
Shucai Xiao's avatar
Shucai Xiao committed
817
818

        // the following logic is to ensure the last instruction is a concat
Shucai Xiao's avatar
Shucai Xiao committed
819
        if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
820
        {
Shucai Xiao's avatar
Shucai Xiao committed
821
            cell_outputs = concat_cell_output;
Shucai Xiao's avatar
Shucai Xiao committed
822
823
824
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
825
            ret_forward[1] =
Shucai Xiao's avatar
Shucai Xiao committed
826
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
Shucai Xiao's avatar
Shucai Xiao committed
827
            ret_reverse[1] =
Shucai Xiao's avatar
Shucai Xiao committed
828
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
Shucai Xiao's avatar
Shucai Xiao committed
829
830
831
832
833
834
835

            ret_forward[3] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_forward[3]);
            ret_reverse[3] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[3], ret_reverse[2]);
            cell_outputs =
                prog.insert_instruction(ins, op::concat{1}, ret_forward[3], ret_reverse[3]);
Shucai Xiao's avatar
Shucai Xiao committed
836
        }
Shucai Xiao's avatar
Shucai Xiao committed
837
838
839

        hidden_state =
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[1], ret_reverse[1]});
Shucai Xiao's avatar
Shucai Xiao committed
840
841
842
    }
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
843
        bool is_forward = (dirct == op::rnn_direction::forward);
Shucai Xiao's avatar
Shucai Xiao committed
844
845
846
847
848
849
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
850
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
851
852
853
854
855
856
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
Shucai Xiao's avatar
Shucai Xiao committed
857
        if(args.size() >= 6 && args[5]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
858
859
860
861
862
863
864
865
866
867
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
Shucai Xiao's avatar
Shucai Xiao committed
868
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
869
870
871
872
873
874
875
876
877
        {
            ic = args[6];
        }
        else
        {
            ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
878
        instruction_ref pph = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
879
880
881
882
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }
Shucai Xiao's avatar
Shucai Xiao committed
883

Shucai Xiao's avatar
Shucai Xiao committed
884
885
886
887
888
        if(!is_forward and variable_seq_len)
        {
            args[0] =
                prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
        }
Shucai Xiao's avatar
Shucai Xiao committed
889
        auto ret = lstm_cell(is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
890
891
                             prog,
                             ins,
Shucai Xiao's avatar
Shucai Xiao committed
892
                             {args[0], w, r, bias, seq_lens, ih, ic, pph},
Shucai Xiao's avatar
Shucai Xiao committed
893
894
895
896
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

Shucai Xiao's avatar
Shucai Xiao committed
897
898
899
        last_hs_output   = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
        last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[3]);

Shucai Xiao's avatar
Shucai Xiao committed
900
        if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
901
        {
Shucai Xiao's avatar
Shucai Xiao committed
902
903
            cell_outputs = ret[3];
            hidden_state = prog.replace_instruction(ins, op::concat{0}, ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
904
905
906
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
907
908
909
910
911
            auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
            auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
            cell_outputs =
                prog.insert_instruction(ins, op::concat{0}, concat_cell_arg0, concat_cell_arg1);

Shucai Xiao's avatar
Shucai Xiao committed
912
913
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
Shucai Xiao's avatar
Shucai Xiao committed
914
            hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
915
916
917
        }
    }

918
919
920
921
922
    // in case of all sequences are of the same lengths and shorter than the
    // max sequence length, need to pad 0's at the end for output hidden states
    hidden_state = pad_hidden_states(prog, args[0], seq_lens, hidden_state);

    // replace last hidden states with corresponding instructions
Shucai Xiao's avatar
Shucai Xiao committed
923
    ins = replace_last_hs_output(prog, hidden_state, seq_lens, last_hs_output, dirct);
924
925

    // replace last cell outputs with corresponding instructions
Shucai Xiao's avatar
Shucai Xiao committed
926
    replace_last_cell_output(prog, ins, seq_lens, cell_outputs, last_cell_output, dirct);
Shucai Xiao's avatar
Shucai Xiao committed
927
928
929
}

std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
930
931
932
933
934
935
                                                    program& prog,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
Shucai Xiao's avatar
Shucai Xiao committed
936
{
Shucai Xiao's avatar
Shucai Xiao committed
937
    // must have 7 args in the input vector
Shucai Xiao's avatar
Shucai Xiao committed
938
939
940
941
942
943
944
945
946
    assert(inputs.size() == 8);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
    auto ic       = inputs.at(6);
    auto pph      = inputs.at(7);
Shucai Xiao's avatar
Shucai Xiao committed
947

Shucai Xiao's avatar
Shucai Xiao committed
948
    instruction_ref hidden_states = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
949
950
951
    instruction_ref cell_outputs  = prog.end();

    instruction_ref last_hs_output{};
Shucai Xiao's avatar
Shucai Xiao committed
952
953
    instruction_ref last_cell_output{};

954
955
956
    migraphx::shape r_shape = r->get_shape();
    long hs                 = static_cast<long>(r_shape.lens()[2]);
    auto bs                 = ih->get_shape().lens()[1];
Shucai Xiao's avatar
Shucai Xiao committed
957
958

    std::vector<int64_t> perm{1, 0};
959
    // w matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
960
    auto sw  = prog.insert_instruction(ins, op::squeeze{{0}}, w);
961
    auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
962

963
    // r matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
964
    auto sr  = prog.insert_instruction(ins, op::squeeze{{0}}, r);
965
    auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);
Shucai Xiao's avatar
Shucai Xiao committed
966

Shucai Xiao's avatar
Shucai Xiao committed
967
968
969
970
    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // initial cell state
Shucai Xiao's avatar
Shucai Xiao committed
971
    auto sic     = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
972
    auto ic_lens = sic->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
973

974
    // bias
975
    instruction_ref wrb{};
Shucai Xiao's avatar
Shucai Xiao committed
976
    if(bias != prog.end())
977
    {
978

Shucai Xiao's avatar
Shucai Xiao committed
979
980
981
        auto sbias  = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto ub_wb  = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
        auto ub_rb  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
982
        auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
983

984
985
        wrb = prog.insert_instruction(
            ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
986
987
    }

Shucai Xiao's avatar
Shucai Xiao committed
988
    // peep hole
Shucai Xiao's avatar
Shucai Xiao committed
989
990
991
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
Shucai Xiao's avatar
Shucai Xiao committed
992
    if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
993
    {
Shucai Xiao's avatar
Shucai Xiao committed
994
995
        auto spph  = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
        auto pphi  = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
996
        pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi);
Shucai Xiao's avatar
Shucai Xiao committed
997

Shucai Xiao's avatar
Shucai Xiao committed
998
        auto ppho  = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
999
        ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho);
Shucai Xiao's avatar
Shucai Xiao committed
1000

Shucai Xiao's avatar
Shucai Xiao committed
1001
        auto pphf  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
1002
        pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf);
Shucai Xiao's avatar
Shucai Xiao committed
1003
    }
Shucai Xiao's avatar
Shucai Xiao committed
1004

Shucai Xiao's avatar
Shucai Xiao committed
1005
    long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
Shucai Xiao's avatar
Shucai Xiao committed
1006
    for(long i = 0; i < seq_len; ++i)
Shucai Xiao's avatar
Shucai Xiao committed
1007
1008
1009
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
Shucai Xiao's avatar
Shucai Xiao committed
1010
1011
        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
        xt           = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
Shucai Xiao's avatar
Shucai Xiao committed
1012

1013
1014
        auto xt_tsw  = prog.insert_instruction(ins, op::dot{}, xt, tsw);
        auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
Shucai Xiao's avatar
Shucai Xiao committed
1015
        auto xt_sih  = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
Shucai Xiao's avatar
Shucai Xiao committed
1016
        if(bias != prog.end())
1017
        {
Shucai Xiao's avatar
Shucai Xiao committed
1018
            xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
1019
        }
Shucai Xiao's avatar
Shucai Xiao committed
1020

1021
1022
        auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
        auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
Shucai Xiao's avatar
Shucai Xiao committed
1023
1024
1025
1026
        auto ft_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
        auto ct_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
1027

Shucai Xiao's avatar
Shucai Xiao committed
1028
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
1029
        {
Shucai Xiao's avatar
Shucai Xiao committed
1030
1031
1032
            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);

Shucai Xiao's avatar
Shucai Xiao committed
1033
1034
            auto pphf_ct   = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
            ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
Shucai Xiao's avatar
Shucai Xiao committed
1035
        }
Shucai Xiao's avatar
Shucai Xiao committed
1036
        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
1037
        auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
1038
        auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
1039
1040

        // equation Ct = ft (.) Ct-1 + it (.) ct
Shucai Xiao's avatar
Shucai Xiao committed
1041
1042
1043
        auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic);
        auto it_ct   = prog.insert_instruction(ins, op::mul{}, it, ct);
        auto cellt   = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
1044

Shucai Xiao's avatar
Shucai Xiao committed
1045
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
1046
        {
Shucai Xiao's avatar
Shucai Xiao committed
1047
1048
            auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
            ot_before_actv  = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
Shucai Xiao's avatar
Shucai Xiao committed
1049
        }
1050
1051
1052
1053
        auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);

        // Ht = ot (.) h(Ct)
        auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
Shucai Xiao's avatar
Shucai Xiao committed
1054
        auto ht      = prog.insert_instruction(ins, op::mul{}, ot, h_cellt);
1055
1056
1057
1058

        sic = cellt;
        sih = ht;

Shucai Xiao's avatar
Shucai Xiao committed
1059
1060
        last_hs_output   = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
        last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, cellt);
1061

Shucai Xiao's avatar
Shucai Xiao committed
1062
        if(i < seq_len - 1)
1063
        {
Shucai Xiao's avatar
Shucai Xiao committed
1064
            if(i == 0)
1065
            {
Shucai Xiao's avatar
Shucai Xiao committed
1066
1067
                hidden_states = last_hs_output;
                cell_outputs  = last_cell_output;
1068
1069
1070
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
1071
1072
                auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
                auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
Shucai Xiao's avatar
Shucai Xiao committed
1073
                hidden_states =
Shucai Xiao's avatar
Shucai Xiao committed
1074
1075
1076
1077
1078
1079
                    prog.insert_instruction(ins, op::concat{0}, concat_hs_arg0, concat_hs_arg1);

                auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
                auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
                cell_outputs =
                    prog.insert_instruction(ins, op::concat{0}, concat_cell_arg0, concat_cell_arg1);
1080
1081
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1082
    }
1083

Shucai Xiao's avatar
Shucai Xiao committed
1084
    return {hidden_states, last_hs_output, cell_outputs, last_cell_output};
Shucai Xiao's avatar
Shucai Xiao committed
1085
1086
1087
1088
1089
1090
1091
1092
1093
}

std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    // Normalize the user-supplied activation list of an lstm operator before it
    // is rewritten: each direction consumes three functions, so a bidirectional
    // lstm needs 6 and a single-direction lstm needs 3. When fewer are given,
    // the missing ones are filled in following the same scheme parse_lstm uses
    // (sigmoid/tanh/tanh when nothing is supplied). Lists that are already long
    // enough are returned unchanged.
    auto lstm_op      = any_cast<op::lstm>(ins->get_operator());
    const auto& given = lstm_op.actv_funcs;

    if(lstm_op.direction != op::rnn_direction::bidirectional)
    {
        if(given.empty())
            return {op::sigmoid{}, op::tanh{}, op::tanh{}};
        if(given.size() == 1)
            return {given.at(0), given.at(0), given.at(0)};
        if(given.size() == 2)
            return {given.at(0), given.at(1), given.at(1)};
        return given;
    }

    if(given.empty())
        return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};
    if(given.size() >= 6)
        return given;

    // Row k describes the case of k + 1 supplied functions: entry j is the index
    // of the supplied function that fills output slot j.
    const std::vector<std::vector<std::size_t>> slot_sources = {{0, 0, 0, 0, 0, 0},
                                                                {0, 1, 1, 0, 1, 1},
                                                                {0, 1, 2, 0, 1, 2},
                                                                {0, 1, 2, 3, 3, 3},
                                                                {0, 1, 2, 3, 4, 4}};
    std::vector<operation> filled;
    for(std::size_t src : slot_sources.at(given.size() - 1))
        filled.push_back(given.at(src));
    return filled;
}

Shucai Xiao's avatar
Shucai Xiao committed
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_lens) const
{
    // Without a sequence-length input, the lengths are uniform.
    if(seq_lens == prog.end())
        return false;

    // Lengths that cannot be evaluated at compile time must be treated as variable.
    if(not seq_lens->can_eval())
        return true;

    auto evaluated = seq_lens->eval();
    std::vector<int64_t> lens;
    evaluated.visit([&](auto values) { lens.assign(values.begin(), values.end()); });
    if(lens.empty())
        return false;

    // Variable only if some entry differs from the first one.
    const int64_t first = lens.front();
    return not std::all_of(lens.begin(), lens.end(), [&](auto len) { return len == first; });
}

std::size_t
rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const
{
    // Number of time steps to unroll: the full time dimension of the input,
    // unless a compile-time-known, uniform seq_lens overrides it.
    const std::size_t max_len = input->get_shape().lens()[0];
    if(seq_lens == prog.end() or is_variable_seq_lens(prog, seq_lens))
        return max_len;

    auto evaluated = seq_lens->eval();
    std::vector<std::size_t> lens;
    evaluated.visit([&](auto values) { lens.assign(values.begin(), values.end()); });
    return lens.empty() ? max_len : lens.front();
}

instruction_ref rewrite_rnn::replace_last_hs_output(program& prog,
                                                    instruction_ref ins,
                                                    instruction_ref seq_lens,
                                                    instruction_ref last_hs_output,
                                                    op::rnn_direction dirct) const
{
    // Rewires every "rnn_last_hs_output" consumer of the hidden-state output `ins`.
    // Returns the instruction that now carries the hidden states: `ins` itself for
    // uniform lengths, or the inserted rnn_var_sl_shift_output for variable lengths.
    auto is_last_hs = [](auto i) { return i->name() == "rnn_last_hs_output"; };

    if(not is_variable_seq_lens(prog, seq_lens))
    {
        // Uniform lengths: the last hidden state is already known statically.
        const auto consumers = find_all(ins->outputs(), is_last_hs);
        for(auto consumer : consumers)
        {
            prog.replace_instruction(consumer, last_hs_output);
        }
        return ins;
    }

    // Variable lengths: route all users of `ins` through a shift_output, then let
    // each last-output consumer pick its value per sequence from it.
    auto shifted = prog.insert_instruction(
        std::next(ins), op::rnn_var_sl_shift_output{"hidden_states", dirct}, ins, seq_lens);
    prog.replace_instruction(ins, shifted);

    const auto consumers = find_all(shifted->outputs(), is_last_hs);
    for(auto consumer : consumers)
    {
        auto consumer_inputs = consumer->inputs();
        prog.replace_instruction(
            consumer, op::rnn_var_sl_last_output{dirct}, consumer_inputs.front(), seq_lens);
    }
    return shifted;
}

/// Rewrite every `rnn_last_cell_output` user of `ins`.
///
/// Variable sequence lengths (as reported by is_variable_seq_lens):
///   - if at least one such user exists, an
///     rnn_var_sl_shift_output{"cell_outputs", dirct}(cell_outputs, seq_lens)
///     is inserted right after `ins`;
///   - each user is replaced by rnn_var_sl_last_output{dirct} applied to that
///     shifted value and `seq_lens`.
///
/// Otherwise:
///   - each user is replaced directly by `last_cell_output`.
///
/// @param prog              program being rewritten
/// @param ins               instruction whose `rnn_last_cell_output` users are rewritten
/// @param seq_lens          sequence-length input passed to is_variable_seq_lens
/// @param cell_outputs      cell-output sequence fed to the shift instruction
/// @param last_cell_output  replacement used when sequence lengths are not variable
/// @param dirct             direction forwarded to the variable-length operators
void rewrite_rnn::replace_last_cell_output(program& prog,
                                           instruction_ref ins,
                                           instruction_ref seq_lens,
                                           instruction_ref cell_outputs,
                                           instruction_ref last_cell_output,
                                           op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
    auto ins_outputs =
        find_all(ins->outputs(), [&](auto i) { return i->name() == "rnn_last_cell_output"; });

    if(variable_seq_len)
    {
        // Only insert the shift instruction when something will consume it.
        if(!ins_outputs.empty())
        {
            cell_outputs =
                prog.insert_instruction(std::next(ins),
                                        op::rnn_var_sl_shift_output{"cell_outputs", dirct},
                                        cell_outputs,
                                        seq_lens);
        }

        for(auto co : ins_outputs)
        {
            prog.replace_instruction(co, op::rnn_var_sl_last_output{dirct}, cell_outputs, seq_lens);
        }
    }
    else
    {
        // Replace each rnn_last_cell_output with last_cell_output. The loop
        // handles the case of multiple rnn_last_cell_output operators.
        for(auto co : ins_outputs)
        {
            prog.replace_instruction(co, last_cell_output);
        }
    }
}

/// Zero-pad the hidden-state output `hs` along dimension 0 up to the maximum
/// sequence length.
///
/// The maximum length is `seq`'s dimension 0. The effective length comes from
/// get_seq_len(prog, seq, seq_lens). When the effective length is already at
/// least the maximum, `hs` is returned untouched.
///
/// Otherwise a zero-filled literal covering the missing rows is concatenated
/// after `hs` on axis 0. The result is inserted right after `hs`, and `hs`'s
/// users are redirected to it.
///
/// @return the padded instruction, or `hs` when no padding is needed.
instruction_ref rewrite_rnn::pad_hidden_states(program& prog,
                                               instruction_ref seq,
                                               instruction_ref seq_lens,
                                               instruction_ref hs) const
{
    const auto max_seq_len = seq->get_shape().lens()[0];
    const auto seq_len     = get_seq_len(prog, seq, seq_lens);

    // All sequences already span the full length: nothing to append.
    if(seq_len >= max_seq_len)
    {
        return hs;
    }

    const auto hs_shape = hs->get_shape();
    auto zero_lens      = hs_shape.lens();
    zero_lens[0]        = static_cast<std::size_t>(max_seq_len - seq_len);

    shape zero_shape{hs_shape.type(), zero_lens};
    std::vector<float> zeros(zero_shape.elements(), 0.0f);
    auto zero_block = prog.add_literal(zero_shape, zeros.begin(), zeros.end());

    auto padded = prog.insert_instruction(std::next(hs), op::concat{0}, hs, zero_block);
    prog.replace_instruction(hs, padded);
    return padded;
}

namespace op {
Shucai Xiao's avatar
Shucai Xiao committed
1311
std::ostream& operator<<(std::ostream& os, rnn_direction v)
1312
{
Shucai Xiao's avatar
Shucai Xiao committed
1313
1314
    std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
1315
1316
    return os;
}
Shucai Xiao's avatar
Shucai Xiao committed
1317
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx