rewrite_rnn.cpp 42.2 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
18
19
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
20
#include <migraphx/op/common.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
21
22
23
24

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Shucai Xiao's avatar
Shucai Xiao committed
25
// Walk every instruction of the program and dispatch each recurrent
// operator ("rnn", "gru" or "lstm") to its dedicated rewrite routine.
// Instructions with any other name are left untouched.
void rewrite_rnn::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        // read the name once; at most one of the rewrites below applies
        const auto op_name = ins->name();

        if(op_name == "rnn")
        {
            apply_vanilla_rnn(prog, ins);
            continue;
        }

        if(op_name == "gru")
        {
            apply_gru(prog, ins);
            continue;
        }

        if(op_name == "lstm")
        {
            apply_lstm(prog, ins);
        }
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
44
// Replace one "rnn" instruction by the primitive instructions produced by
// vanilla_rnn_cell(), finishing with a concat that takes the place of `ins`.
// Every "rnn_last_output" consumer of `ins` is redirected to the squeezed
// output of the final time step.
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "rnn");
    // The operator may carry 3 to 6 inputs. Per the original notes, the
    // parse_rnn function pads the list with "undefined" operators up to 6
    // arguments when reading an onnx file, while a hand-written program
    // may pass any number of arguments.
    auto args = ins->inputs();

    const shape seq_shape         = args[0]->get_shape();
    const std::size_t hidden_size = args[1]->get_shape().lens()[1];
    const std::size_t batch_size  = seq_shape.lens()[1];
    const shape::type_t type      = seq_shape.type();

    // zero-filled data used as initial hidden state when none is supplied
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0);

    auto actv_funcs               = vanilla_rnn_actv_funcs(ins);
    auto rnn_op                   = any_cast<op::rnn>(ins->get_operator());
    const op::rnn_direction dicrt = rnn_op.direction;

    // input 3 is the bias, input 5 the initial hidden state; either one is
    // treated as absent when missing or when it is an "undefined" instruction
    const bool has_bias       = args.size() >= 4 && args[3]->name() != "undefined";
    const bool has_init_state = args.size() == 6 && args[5]->name() != "undefined";

    // inserts a slice selecting entry `dir` along axis 0 of `src`
    const auto slice_direction = [&](int64_t dir, instruction_ref src) {
        return prog.insert_instruction(ins, op::slice{{0}, {dir}, {dir + 1}}, src);
    };

    instruction_ref last_output{};
    if(dicrt == op::rnn_direction::bidirectional)
    {
        // input weight matrix, one slice per direction
        auto fwd_w = slice_direction(0, args[1]);
        auto rev_w = slice_direction(1, args[1]);

        // hidden state weight matrix, one slice per direction
        auto fwd_r = slice_direction(0, args[2]);
        auto rev_r = slice_direction(1, args[2]);

        // bias; prog.end() tells the cell that no bias is present
        instruction_ref fwd_bias = prog.end();
        instruction_ref rev_bias = prog.end();
        if(has_bias)
        {
            fwd_bias = slice_direction(0, args[3]);
            rev_bias = slice_direction(1, args[3]);
        }

        // initial hidden state; falls back to zero literals
        instruction_ref fwd_ih{};
        instruction_ref rev_ih{};
        if(has_init_state)
        {
            fwd_ih = slice_direction(0, args[5]);
            rev_ih = slice_direction(1, args[5]);
        }
        else
        {
            fwd_ih = prog.add_literal(migraphx::literal{ih_shape, data});
            rev_ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto fwd_ret = vanilla_rnn_cell(
            true, prog, ins, args[0], fwd_w, fwd_r, fwd_bias, fwd_ih, actv_funcs.at(0));
        auto rev_ret = vanilla_rnn_cell(
            false, prog, ins, args[0], rev_w, rev_r, rev_bias, rev_ih, actv_funcs.at(1));

        auto last_concat = prog.insert_instruction(ins, op::concat{1}, fwd_ret[1], rev_ret[1]);
        last_output      = prog.insert_instruction(ins, op::squeeze{{0}}, last_concat);

        // Make sure the instruction that replaces the rnn operator is a
        // concat. The cell leaves its first entry at prog.end() when the
        // sequence length is 1.
        if(fwd_ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{1}, fwd_ret[1], rev_ret[1]);
        }
        else
        {
            // forward: earlier steps first; reverse: last computed step first
            fwd_ret[0] = prog.insert_instruction(ins, op::concat{0}, fwd_ret[0], fwd_ret[1]);
            rev_ret[0] = prog.insert_instruction(ins, op::concat{0}, rev_ret[1], rev_ret[0]);
            prog.replace_instruction(ins, op::concat{1}, {fwd_ret[0], rev_ret[0]});
        }
    }
    else
    {
        const bool is_forward = (dicrt == op::rnn_direction::forward);

        // input weight matrix and hidden state weight matrix are used as-is
        auto w = args[1];
        auto r = args[2];

        // bias; prog.end() tells the cell that no bias is present
        instruction_ref bias = prog.end();
        if(has_bias)
        {
            bias = args[3];
        }

        // initial hidden state; falls back to a zero literal
        instruction_ref ih;
        if(has_init_state)
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret =
            vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        // Make sure the instruction that replaces the rnn operator is a
        // concat. The cell leaves its first entry at prog.end() when the
        // sequence length is 1.
        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto first_part  = is_forward ? ret[0] : ret[1];
            auto second_part = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, first_part, second_part);
        }
    }

    // Redirect every rnn_last_output consumer to last_output. The search
    // resumes after each hit so that several such consumers are handled.
    const auto is_last_output = [](auto i) { return i->name() == "rnn_last_output"; };
    auto out_it = std::find_if(ins->outputs().begin(), ins->outputs().end(), is_last_output);
    while(out_it != ins->outputs().end())
    {
        prog.replace_instruction(*out_it, last_output);
        out_it++;
        out_it = std::find_if(out_it, ins->outputs().end(), is_last_output);
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
200
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
201
202
203
204
205
206
207
208
                                                           program& prog,
                                                           instruction_ref ins,
                                                           instruction_ref input,
                                                           instruction_ref w,
                                                           instruction_ref r,
                                                           instruction_ref bias,
                                                           instruction_ref ih,
                                                           operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
209
{
Shucai Xiao's avatar
Shucai Xiao committed
210
211
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
212
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
Shucai Xiao's avatar
Shucai Xiao committed
213
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
214
215

    // squeeze and transpose r
Shucai Xiao's avatar
Shucai Xiao committed
216
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
Shucai Xiao's avatar
Shucai Xiao committed
217
218
219
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
Shucai Xiao's avatar
Shucai Xiao committed
220
    auto sih      = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
221
    auto sih_lens = sih->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
222
223

    // bias
224
    instruction_ref bb{};
Shucai Xiao's avatar
Shucai Xiao committed
225
    if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
226
    {
227
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
Shucai Xiao's avatar
Shucai Xiao committed
228
229
230
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
231
232
        auto wrb   = prog.insert_instruction(ins, op::add{}, wb, rb);
        bb         = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
Shucai Xiao's avatar
Shucai Xiao committed
233
234
    }

Shucai Xiao's avatar
Shucai Xiao committed
235
236
    instruction_ref hidden_out = prog.end();
    instruction_ref last_out{};
Shucai Xiao's avatar
Shucai Xiao committed
237
238
    last_out            = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
    std::size_t seq_len = input->get_shape().lens()[0];
Shucai Xiao's avatar
Shucai Xiao committed
239
240
    for(std::size_t i = 0; i < seq_len; i++)
    {
Shucai Xiao's avatar
Shucai Xiao committed
241
        long seq_index = is_forward ? i : (seq_len - 1 - i);
Shucai Xiao's avatar
Shucai Xiao committed
242
243
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
244
245
        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
Shucai Xiao's avatar
Shucai Xiao committed
246
        if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
247
        {
248
            xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
Shucai Xiao's avatar
Shucai Xiao committed
249
        }
250
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
Shucai Xiao's avatar
Shucai Xiao committed
251
252

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
253
254
        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;
Shucai Xiao's avatar
Shucai Xiao committed
255

Shucai Xiao's avatar
Shucai Xiao committed
256
257
258
        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
Shucai Xiao's avatar
Shucai Xiao committed
259

Shucai Xiao's avatar
Shucai Xiao committed
260
261
262
        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
Shucai Xiao's avatar
Shucai Xiao committed
263
        if(i < seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
264
        {
Shucai Xiao's avatar
Shucai Xiao committed
265
266
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
267
268
269
270
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
271
272
273
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
274
275
276
277
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
278
            }
Shucai Xiao's avatar
Shucai Xiao committed
279
280
281
        }
    }

282
    return {hidden_out, last_out};
Shucai Xiao's avatar
Shucai Xiao committed
283
284
}

Shucai Xiao's avatar
Shucai Xiao committed
285
// Build the activation function list used when rewriting an "rnn"
// instruction: two entries for a bidirectional operator, otherwise the
// operator's own list (or a single tanh when that list is empty).
//   - empty list            -> tanh for every direction
//   - bidirectional, 1 func -> that function used for both directions
//   - anything else         -> the operator's list returned unchanged
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    const auto rnn_op        = any_cast<op::rnn>(ins->get_operator());
    const auto& funcs        = rnn_op.actv_funcs;
    const bool bidirectional = (rnn_op.direction == op::rnn_direction::bidirectional);

    if(funcs.empty())
    {
        // default is tanh
        if(bidirectional)
        {
            return {op::tanh{}, op::tanh{}};
        }
        return {op::tanh{}};
    }

    if(bidirectional && funcs.size() == 1)
    {
        return {funcs.at(0), funcs.at(0)};
    }

    return funcs;
}

322
323
324
325
// Replace one "gru" instruction by the primitive instructions produced by
// gru_cell(), finishing with a concat that takes the place of `ins`.
// Every "rnn_last_output" consumer of `ins` is redirected to the squeezed
// output of the final time step.
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);
    // The operator may carry 3 to 6 inputs. Per the original notes, the
    // parse_gru function pads the list with "undefined" operators up to 6
    // arguments when reading an onnx file, while a hand-written program
    // may pass any number of arguments.
    auto args = ins->inputs();

    const shape seq_shape         = args[0]->get_shape();
    const std::size_t hidden_size = args[2]->get_shape().lens()[2];
    const std::size_t batch_size  = seq_shape.lens()[1];
    const shape::type_t type      = seq_shape.type();

    // zero-filled data used as initial hidden state when none is supplied
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0.0);

    auto gru_op                   = any_cast<op::gru>(ins->get_operator());
    const op::rnn_direction dicrt = gru_op.direction;

    // input 3 is the bias, input 5 the initial hidden state; either one is
    // treated as absent when missing or when it is an "undefined" instruction
    const bool has_bias       = args.size() >= 4 && args[3]->name() != "undefined";
    const bool has_init_state = args.size() == 6 && args[5]->name() != "undefined";

    // inserts a slice selecting entry `dir` along axis 0 of `src`
    const auto slice_direction = [&](int64_t dir, instruction_ref src) {
        return prog.insert_instruction(ins, op::slice{{0}, {dir}, {dir + 1}}, src);
    };

    instruction_ref last_output{};
    if(dicrt == op::rnn_direction::bidirectional)
    {
        // w weight matrix, one slice per direction
        auto fwd_w = slice_direction(0, args[1]);
        auto rev_w = slice_direction(1, args[1]);

        // r weight matrix, one slice per direction
        auto fwd_r = slice_direction(0, args[2]);
        auto rev_r = slice_direction(1, args[2]);

        // bias; prog.end() tells the cell that no bias is present
        instruction_ref fwd_bias = prog.end();
        instruction_ref rev_bias = prog.end();
        if(has_bias)
        {
            fwd_bias = slice_direction(0, args[3]);
            rev_bias = slice_direction(1, args[3]);
        }

        // initial hidden state; falls back to zero literals
        instruction_ref fwd_ih{};
        instruction_ref rev_ih{};
        if(has_init_state)
        {
            fwd_ih = slice_direction(0, args[5]);
            rev_ih = slice_direction(1, args[5]);
        }
        else
        {
            fwd_ih = prog.add_literal(migraphx::literal{ih_shape, data});
            rev_ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        // activation functions 0 and 1 serve the forward pass, 2 and 3 the
        // reverse pass
        auto fwd_ret = gru_cell(true,
                                prog,
                                ins,
                                {args[0], fwd_w, fwd_r, fwd_bias, fwd_ih},
                                gru_op.linear_before_reset,
                                actv_funcs.at(0),
                                actv_funcs.at(1));

        auto rev_ret = gru_cell(false,
                                prog,
                                ins,
                                {args[0], rev_w, rev_r, rev_bias, rev_ih},
                                gru_op.linear_before_reset,
                                actv_funcs.at(2),
                                actv_funcs.at(3));

        auto last_concat = prog.insert_instruction(ins, op::concat{1}, fwd_ret[1], rev_ret[1]);
        last_output      = prog.insert_instruction(ins, op::squeeze{{0}}, last_concat);

        // Make sure the instruction that replaces the gru operator is a
        // concat. The cell leaves its first entry at prog.end() when there
        // is only one time step.
        if(fwd_ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{1}, fwd_ret[1], rev_ret[1]);
        }
        else
        {
            // forward: earlier steps first; reverse: last computed step first
            fwd_ret[0] = prog.insert_instruction(ins, op::concat{0}, fwd_ret[0], fwd_ret[1]);
            rev_ret[0] = prog.insert_instruction(ins, op::concat{0}, rev_ret[1], rev_ret[0]);
            prog.replace_instruction(ins, op::concat{1}, {fwd_ret[0], rev_ret[0]});
        }
    }
    else
    {
        const bool is_forward = (dicrt == op::rnn_direction::forward);

        // weight matrices are used as-is
        auto w = args[1];
        auto r = args[2];

        // bias; prog.end() tells the cell that no bias is present
        instruction_ref bias = prog.end();
        if(has_bias)
        {
            bias = args[3];
        }

        // initial hidden state; falls back to a zero literal
        instruction_ref ih{};
        if(has_init_state)
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret = gru_cell(is_forward,
                            prog,
                            ins,
                            {args[0], w, r, bias, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        // same concat-last rule as in the bidirectional branch
        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto first_part  = is_forward ? ret[0] : ret[1];
            auto second_part = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, first_part, second_part);
        }
    }

    // Redirect every rnn_last_output consumer to last_output. The search
    // resumes after each hit so that several such consumers are handled.
    const auto is_last_output = [](auto i) { return i->name() == "rnn_last_output"; };
    auto out_it = std::find_if(ins->outputs().begin(), ins->outputs().end(), is_last_output);
    while(out_it != ins->outputs().end())
    {
        prog.replace_instruction(*out_it, last_output);
        out_it++;
        out_it = std::find_if(out_it, ins->outputs().end(), is_last_output);
    }
}

std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
476
477
478
479
480
481
                                                   program& prog,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
482
483
484
485
486
487
488
489
{
    assert(inputs.size() == 5);
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
    auto bias = inputs.at(3);
    auto ih   = inputs.at(4);

Shucai Xiao's avatar
Shucai Xiao committed
490
491
    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
Shucai Xiao's avatar
Shucai Xiao committed
492
    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
493
494
495
    migraphx::shape r_shape   = r->get_shape();
    long seq_len              = static_cast<long>(seq_shape.lens()[0]);
    long hs                   = static_cast<long>(r_shape.lens()[2]);
496

Shucai Xiao's avatar
Shucai Xiao committed
497
    migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
498
    std::vector<float> data(s.elements(), 1.0f);
499
500
    auto l1 = prog.add_literal(migraphx::literal{s, data});

501
    // w matrix squeeze to 2-dim and do a transpose
502
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
503
    auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
504
    auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
505

506
    // r slide to two part, zr and h
Shucai Xiao's avatar
Shucai Xiao committed
507
508
509
    auto sr   = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rzr  = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
    auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
510

Shucai Xiao's avatar
Shucai Xiao committed
511
    auto rh  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
512
    auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
513
514

    // initial states
Shucai Xiao's avatar
Shucai Xiao committed
515
516
    auto sih  = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
    size_t bs = ih->get_shape().lens()[1];
517
518

    // bias
519
520
521
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
522
523
524
    if(bias != prog.end())
    {
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
525
526
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
        bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
527
528

        auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
Shucai Xiao's avatar
Shucai Xiao committed
529
530
531
532
        auto rb_h  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brb_zr     = prog.insert_instruction(
            ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
        brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
533
534
535
536
537
538
539
540
    }

    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

541
542
        auto xt_w    = prog.insert_instruction(ins, op::dot{}, xt, tw);
        auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
Shucai Xiao's avatar
Shucai Xiao committed
543
        if(bias != prog.end())
544
        {
545
546
            xt_w    = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
            ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
547
548
549
550
551
552
553
554
555
556
        }

        auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
        auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
        auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);

        auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
        auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);

        auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
Shucai Xiao's avatar
Shucai Xiao committed
557
        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);
558
559

        auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
Shucai Xiao's avatar
Shucai Xiao committed
560
        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);
561
562
563
564
565
566

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
Shucai Xiao's avatar
Shucai Xiao committed
567
            if(bias != prog.end())
568
            {
Shucai Xiao's avatar
Shucai Xiao committed
569
                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
570
            }
571
572
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
573
                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
574
            }
575
576
577
        }
        else
        {
578
579
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            instruction_ref ht1_rh{};
Shucai Xiao's avatar
Shucai Xiao committed
580
            if(bias != prog.end())
581
            {
582
                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
583
            }
584
            else
585
            {
Shucai Xiao's avatar
Shucai Xiao committed
586
                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
587
            }
Shucai Xiao's avatar
Shucai Xiao committed
588
            hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
589
590
        }

591
        auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
Shucai Xiao's avatar
Shucai Xiao committed
592
        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);
593

594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
        last_output          = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}

// Build the activation function list used when rewriting a "gru"
// instruction. A bidirectional operator is completed to four entries and
// any other direction to two, whenever the operator carries fewer; longer
// lists are returned unchanged. The defaults are sigmoid followed by tanh.
// The original note says the completion follows the algorithm in parse_gru.
std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    const auto gru_op = any_cast<op::gru>(ins->get_operator());
    const auto& funcs = gru_op.actv_funcs;

    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        switch(funcs.size())
        {
        case 0: return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
        case 1: return {funcs.at(0), funcs.at(0), funcs.at(0), funcs.at(0)};
        case 2: return {funcs.at(0), funcs.at(1), funcs.at(0), funcs.at(1)};
        // NOTE(review): the fourth entry repeats entry 0 rather than entry 2
        // or 1 - confirm this matches the rule used by the parser.
        case 3: return {funcs.at(0), funcs.at(1), funcs.at(2), funcs.at(0)};
        default: return funcs;
        }
    }

    switch(funcs.size())
    {
    case 0: return {op::sigmoid{}, op::tanh{}};
    case 1: return {funcs.at(0), funcs.at(0)};
    default: return funcs;
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
663
664
665
666
667
// for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();
Shucai Xiao's avatar
Shucai Xiao committed
668

Shucai Xiao's avatar
Shucai Xiao committed
669
    shape seq_shape         = args[0]->get_shape();
670
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
Shucai Xiao's avatar
Shucai Xiao committed
671
672
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
Shucai Xiao's avatar
Shucai Xiao committed
673
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
674
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
Shucai Xiao's avatar
Shucai Xiao committed
675
676

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
677
    std::vector<float> pph_data(pph_shape.elements(), 0.0);
Shucai Xiao's avatar
Shucai Xiao committed
678

Shucai Xiao's avatar
Shucai Xiao committed
679
680
    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
Shucai Xiao's avatar
Shucai Xiao committed
681
    op::rnn_direction dirct = lstm_op.direction;
Shucai Xiao's avatar
Shucai Xiao committed
682
683
684

    instruction_ref last_output{};
    instruction_ref last_cell_output{};
Shucai Xiao's avatar
Shucai Xiao committed
685
    if(dirct == op::rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
686
687
688
689
690
691
692
693
694
695
696
697
698
    {
        // input weight matrix
        // input weight matrix
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // process bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
699
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // process intial hidden state, it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
722
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
723
724
725
726
727
728
729
730
731
732
733
        {
            ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]);
            ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]);
        }
        else
        {
            ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
734
735
        instruction_ref pph_forward = prog.end();
        instruction_ref pph_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
736
        if(args.size() == 8 && args[7]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
737
738
739
740
        {
            pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
            pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
        }
Shucai Xiao's avatar
Shucai Xiao committed
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761

        auto ret_forward = lstm_cell(
            true,
            prog,
            ins,
            {args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward},
            actv_funcs.at(0),
            actv_funcs.at(1),
            actv_funcs.at(2));

        auto ret_reverse = lstm_cell(
            false,
            prog,
            ins,
            {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse},
            actv_funcs.at(3),
            actv_funcs.at(4),
            actv_funcs.at(5));

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
762
763
764
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // last cell output
765
766
        last_cell_output =
            prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_reverse[2]);
Shucai Xiao's avatar
Shucai Xiao committed
767
768

        // the following logic is to ensure the last instruction is a concat
Shucai Xiao's avatar
Shucai Xiao committed
769
        if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
770
771
772
773
774
        {
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
775
776
777
778
779
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
Shucai Xiao's avatar
Shucai Xiao committed
780
781
782
783
        }
    }
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
784
        bool is_forward = (dirct == op::rnn_direction::forward);
Shucai Xiao's avatar
Shucai Xiao committed
785
786
787
788
789
790
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
791
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
792
793
794
795
796
797
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
Shucai Xiao's avatar
Shucai Xiao committed
798
        if(args.size() >= 6 && args[5]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
799
800
801
802
803
804
805
806
807
808
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
Shucai Xiao's avatar
Shucai Xiao committed
809
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
810
811
812
813
814
815
816
817
818
        {
            ic = args[6];
        }
        else
        {
            ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
819
        instruction_ref pph = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
820
821
822
823
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }
Shucai Xiao's avatar
Shucai Xiao committed
824
825

        auto ret = lstm_cell(is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
826
827
828
829
830
831
832
                             prog,
                             ins,
                             {args[0], w, r, bias, ih, ic, pph},
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

Shucai Xiao's avatar
Shucai Xiao committed
833
        last_output      = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
834
        last_cell_output = ret[2];
Shucai Xiao's avatar
Shucai Xiao committed
835
        if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
836
837
838
839
840
841
842
843
844
845
846
847
848
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // replace the corresponding lstm_last_output instruction
    // with the last_output, and the lstm_last_cell_output with
Shucai Xiao's avatar
Shucai Xiao committed
849
    // the last_cell_output. The while loop is to handle the case
Shucai Xiao's avatar
Shucai Xiao committed
850
851
852
853
854
855
    // of multiple lstm_last_output and lstm_last_cell_output
    // operators
    auto last_output_it = ins->outputs().begin();
    while(last_output_it != ins->outputs().end())
    {
        last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
856
            return i->name() == "rnn_last_output";
Shucai Xiao's avatar
Shucai Xiao committed
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
        });

        if(last_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_output_it, last_output);
            last_output_it++;
        }
    }

    auto last_cell_output_it = ins->outputs().begin();
    while(last_cell_output_it != ins->outputs().end())
    {
        last_cell_output_it = std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
            return i->name() == "lstm_last_cell_output";
        });

        if(last_cell_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_cell_output_it, last_cell_output);
            last_cell_output_it++;
        }
Shucai Xiao's avatar
Shucai Xiao committed
878
    }
Shucai Xiao's avatar
Shucai Xiao committed
879
880
881
}

std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
882
883
884
885
886
887
                                                    program& prog,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
Shucai Xiao's avatar
Shucai Xiao committed
888
{
Shucai Xiao's avatar
Shucai Xiao committed
889
890
    // must have 7 args in the input vector
    assert(inputs.size() == 7);
Shucai Xiao's avatar
Shucai Xiao committed
891
892
893
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
Shucai Xiao's avatar
Shucai Xiao committed
894
    auto bias = inputs.at(3);
Shucai Xiao's avatar
Shucai Xiao committed
895
896
897
    auto ih   = inputs.at(4);
    auto ic   = inputs.at(5);
    auto pph  = inputs.at(6);
Shucai Xiao's avatar
Shucai Xiao committed
898

Shucai Xiao's avatar
Shucai Xiao committed
899
900
901
902
903
    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
    instruction_ref last_cell_output{};

    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
904
    migraphx::shape r_shape   = r->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
905
    long seq_len              = static_cast<long>(seq_shape.lens()[0]);
Shucai Xiao's avatar
Shucai Xiao committed
906
    long hs                   = static_cast<long>(r_shape.lens()[2]);
Shucai Xiao's avatar
Shucai Xiao committed
907
    auto bs                   = ih->get_shape().lens()[1];
Shucai Xiao's avatar
Shucai Xiao committed
908
909

    std::vector<int64_t> perm{1, 0};
910
    // w matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
911
    auto sw  = prog.insert_instruction(ins, op::squeeze{{0}}, w);
912
    auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
913

914
    // r matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
915
    auto sr  = prog.insert_instruction(ins, op::squeeze{{0}}, r);
916
    auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);
Shucai Xiao's avatar
Shucai Xiao committed
917

Shucai Xiao's avatar
Shucai Xiao committed
918
919
920
921
    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // initial cell state
Shucai Xiao's avatar
Shucai Xiao committed
922
    auto sic     = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
923
    auto ic_lens = sic->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
924

925
    // bias
926
    instruction_ref wrb{};
Shucai Xiao's avatar
Shucai Xiao committed
927
    if(bias != prog.end())
928
    {
929

Shucai Xiao's avatar
Shucai Xiao committed
930
931
932
        auto sbias  = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto ub_wb  = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
        auto ub_rb  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
933
        auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
934

935
936
        wrb = prog.insert_instruction(
            ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
937
938
    }

Shucai Xiao's avatar
Shucai Xiao committed
939
    // peep hole
Shucai Xiao's avatar
Shucai Xiao committed
940
941
942
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
Shucai Xiao's avatar
Shucai Xiao committed
943
    if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
944
    {
Shucai Xiao's avatar
Shucai Xiao committed
945
946
        auto spph  = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
        auto pphi  = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
947
        pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi);
Shucai Xiao's avatar
Shucai Xiao committed
948

Shucai Xiao's avatar
Shucai Xiao committed
949
        auto ppho  = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
950
        ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho);
Shucai Xiao's avatar
Shucai Xiao committed
951

Shucai Xiao's avatar
Shucai Xiao committed
952
        auto pphf  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
953
        pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf);
Shucai Xiao's avatar
Shucai Xiao committed
954
    }
Shucai Xiao's avatar
Shucai Xiao committed
955

Shucai Xiao's avatar
Shucai Xiao committed
956
    for(long i = 0; i < seq_len; ++i)
Shucai Xiao's avatar
Shucai Xiao committed
957
958
959
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
Shucai Xiao's avatar
Shucai Xiao committed
960
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
Shucai Xiao's avatar
Shucai Xiao committed
961

962
963
        auto xt_tsw  = prog.insert_instruction(ins, op::dot{}, xt, tsw);
        auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
Shucai Xiao's avatar
Shucai Xiao committed
964
        auto xt_sih  = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
Shucai Xiao's avatar
Shucai Xiao committed
965
        if(bias != prog.end())
966
        {
Shucai Xiao's avatar
Shucai Xiao committed
967
            xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
968
        }
Shucai Xiao's avatar
Shucai Xiao committed
969

970
971
        auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
        auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
Shucai Xiao's avatar
Shucai Xiao committed
972
973
974
975
        auto ft_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
        auto ct_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
976

Shucai Xiao's avatar
Shucai Xiao committed
977
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
978
        {
Shucai Xiao's avatar
Shucai Xiao committed
979
980
981
            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);

Shucai Xiao's avatar
Shucai Xiao committed
982
983
            auto pphf_ct   = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
            ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
Shucai Xiao's avatar
Shucai Xiao committed
984
        }
Shucai Xiao's avatar
Shucai Xiao committed
985
        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
986
        auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
987
        auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
988
989

        // equation Ct = ft (.) Ct-1 + it (.) ct
Shucai Xiao's avatar
Shucai Xiao committed
990
991
992
        auto ft_cell     = prog.insert_instruction(ins, op::mul{}, ft, sic);
        auto it_ct       = prog.insert_instruction(ins, op::mul{}, it, ct);
        auto cellt       = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
993
994
        last_cell_output = cellt;

Shucai Xiao's avatar
Shucai Xiao committed
995
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
996
        {
Shucai Xiao's avatar
Shucai Xiao committed
997
998
            auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
            ot_before_actv  = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
Shucai Xiao's avatar
Shucai Xiao committed
999
        }
1000
1001
1002
1003
        auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);

        // Ht = ot (.) h(Ct)
        auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
Shucai Xiao's avatar
Shucai Xiao committed
1004
        auto ht      = prog.insert_instruction(ins, op::mul{}, ot, h_cellt);
1005
1006
1007
1008
1009
1010

        sic = cellt;
        sih = ht;

        last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);

Shucai Xiao's avatar
Shucai Xiao committed
1011
        if(i < seq_len - 1)
1012
        {
Shucai Xiao's avatar
Shucai Xiao committed
1013
            if(i == 0)
1014
            {
Shucai Xiao's avatar
Shucai Xiao committed
1015
                hidden_states = last_output;
1016
1017
1018
1019
1020
            }
            else
            {
                auto concat_arg0 = is_forward ? hidden_states : last_output;
                auto concat_arg1 = is_forward ? last_output : hidden_states;
Shucai Xiao's avatar
Shucai Xiao committed
1021
1022
                hidden_states =
                    prog.insert_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
1023
1024
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1025
    }
1026
1027
1028
1029

    last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0}}, last_cell_output);

    return {hidden_states, last_output, last_cell_output};
Shucai Xiao's avatar
Shucai Xiao committed
1030
1031
1032
1033
1034
1035
1036
1037
1038
}

// Return the activation functions to use when rewriting the lstm
// instruction `ins`. The list stored on the operator may be shorter
// than what the rewrite needs, so it is expanded here: 6 entries for a
// bidirectional lstm (entries 0..2 for the forward pass, 3..5 for the
// reverse pass), 3 entries otherwise. A list that is already long
// enough is returned unchanged.
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    auto lstm_op = any_cast<op::lstm>(ins->get_operator());
    // before rewriting the lstm operator, we need to ensure
    // we have 6 actv funcs, even though a user does not
    // specify any actv func. If less than 6, use the
    // algorithm in parse_lstm to make 6 actv functions
    const auto& actv_funcs     = lstm_op.actv_funcs;
    std::size_t num_actv_funcs = actv_funcs.size();
    if(lstm_op.direction == op::rnn_direction::bidirectional)
    {
        switch(num_actv_funcs)
        {
        // nothing specified: sigmoid, tanh, tanh for both directions
        case 0:
            return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};

        // a single function is used everywhere
        case 1:
            return {actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0)};

        // the second function fills the third slot; the triple is
        // then repeated for the reverse pass
        case 2:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1)};

        // a full triple is repeated for the reverse pass
        case 3:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2)};

        // the last given function fills the remaining slots
        case 4:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(3),
                    actv_funcs.at(3)};

        case 5:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(4),
                    actv_funcs.at(4)};

        default: return actv_funcs;
        }
    }
    else
    {
        switch(num_actv_funcs)
        {
        case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}};

        case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};

        case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};

        default: return actv_funcs;
        }
    }
}

1106
namespace op {
Shucai Xiao's avatar
Shucai Xiao committed
1107
std::ostream& operator<<(std::ostream& os, rnn_direction v)
1108
{
Shucai Xiao's avatar
Shucai Xiao committed
1109
1110
    std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
1111
1112
    return os;
}
Shucai Xiao's avatar
Shucai Xiao committed
1113
} // namespace op
1114

Shucai Xiao's avatar
Shucai Xiao committed
1115
1116
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx