rewrite_rnn.cpp 42.3 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
18
#include <migraphx/op/contiguous.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
19
20
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
21
#include <migraphx/op/common.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
22
23
24
25

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Shucai Xiao's avatar
Shucai Xiao committed
26
void rewrite_rnn::apply(program& prog) const
{
    // Visit every instruction once and lower the recurrent operators
    // ("rnn", "gru", "lstm") into primitive operators in place. Any other
    // instruction is left untouched.
    for(auto ins : iterator_for(prog))
    {
        const auto op_name = ins->name();
        if(op_name == "rnn")
            apply_vanilla_rnn(prog, ins);
        else if(op_name == "gru")
            apply_gru(prog, ins);
        else if(op_name == "lstm")
            apply_lstm(prog, ins);
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
45
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
{
    // Lowers one "rnn" instruction into primitive ops inserted before `ins`.
    // `ins` itself is turned into a concat of the hidden states of every time
    // step. Each "rnn_last_output" consumer of `ins` is redirected to the
    // squeezed hidden state of the last step.
    assert(ins->name() == "rnn");

    // The op carries 3 to 6 inputs. parse_rnn pads ONNX models up to 6 with
    // "undefined" placeholders, while hand-written programs may pass fewer.
    // Read here: [0] sequence, [1] input weights W, [2] hidden weights R,
    // [3] bias (optional), [5] initial hidden state (optional; honored only
    // when exactly 6 inputs are present). Input [4] is not read.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    // zero-filled fallback used when no initial hidden state is supplied
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> zeros(ih_shape.elements(), 0);

    auto actv_funcs       = vanilla_rnn_actv_funcs(ins);
    auto rnn_op           = any_cast<op::rnn>(ins->get_operator());
    op::rnn_direction dir = rnn_op.direction;

    const bool bias_given = args.size() >= 4 && args[3]->name() != "undefined";
    const bool ih_given   = args.size() == 6 && args[5]->name() != "undefined";

    instruction_ref last_output{};
    if(dir == op::rnn_direction::bidirectional)
    {
        // Every per-direction tensor is split along axis 0:
        // slice [0, 1) feeds the forward pass, slice [1, 2) the reverse pass.
        auto w_fwd = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_rev = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
        auto r_fwd = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_rev = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // prog.end() tells vanilla_rnn_cell that there is no bias
        instruction_ref b_fwd = prog.end();
        instruction_ref b_rev = prog.end();
        if(bias_given)
        {
            b_fwd = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            b_rev = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        instruction_ref h0_fwd{};
        instruction_ref h0_rev{};
        if(ih_given)
        {
            h0_fwd = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            h0_rev = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            h0_fwd = prog.add_literal(migraphx::literal{ih_shape, zeros});
            h0_rev = prog.add_literal(migraphx::literal{ih_shape, zeros});
        }

        auto fwd = vanilla_rnn_cell(
            true, prog, ins, args[0], w_fwd, r_fwd, b_fwd, h0_fwd, actv_funcs.at(0));
        auto rev = vanilla_rnn_cell(
            false, prog, ins, args[0], w_rev, r_rev, b_rev, h0_rev, actv_funcs.at(1));

        // final hidden state of both directions, with the leading axis dropped
        auto last_both = prog.insert_instruction(ins, op::concat{1}, fwd[1], rev[1]);
        last_output    = prog.insert_instruction(ins, op::squeeze{{0}}, last_both);

        // Make the final instruction produced from the rnn a concat.
        if(fwd[0] == prog.end())
        {
            // sequence length 1: no earlier hidden states were collected
            prog.replace_instruction(ins, op::concat{1}, fwd[1], rev[1]);
        }
        else
        {
            // The reverse pass finishes at time step 0, so its last state
            // goes in front of the states it collected earlier.
            auto all_fwd = prog.insert_instruction(ins, op::concat{0}, fwd[0], fwd[1]);
            auto all_rev = prog.insert_instruction(ins, op::concat{0}, rev[1], rev[0]);
            prog.replace_instruction(ins, op::concat{1}, {all_fwd, all_rev});
        }
    }
    else
    {
        const bool is_forward = (dir == op::rnn_direction::forward);

        // prog.end() tells vanilla_rnn_cell that there is no bias
        instruction_ref bias = bias_given ? args[3] : prog.end();
        instruction_ref ih =
            ih_given ? args[5] : prog.add_literal(migraphx::literal{ih_shape, zeros});

        auto ret = vanilla_rnn_cell(
            is_forward, prog, ins, args[0], args[1], args[2], bias, ih, actv_funcs.at(0));
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        // Make the final instruction produced from the rnn a concat; the
        // reverse pass's last state belongs to time step 0 and goes first.
        if(ret[0] == prog.end())
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        else if(is_forward)
            prog.replace_instruction(ins, op::concat{0}, ret[0], ret[1]);
        else
            prog.replace_instruction(ins, op::concat{0}, ret[1], ret[0]);
    }

    // There may be several "rnn_last_output" consumers; redirect each one.
    for(auto out : ins->outputs())
    {
        if(out->name() == "rnn_last_output")
            prog.replace_instruction(out, last_output);
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
201
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
202
203
204
205
206
207
208
209
                                                           program& prog,
                                                           instruction_ref ins,
                                                           instruction_ref input,
                                                           instruction_ref w,
                                                           instruction_ref r,
                                                           instruction_ref bias,
                                                           instruction_ref ih,
                                                           operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
210
{
Shucai Xiao's avatar
Shucai Xiao committed
211
212
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
213
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
Shucai Xiao's avatar
Shucai Xiao committed
214
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
215
216

    // squeeze and transpose r
Shucai Xiao's avatar
Shucai Xiao committed
217
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
Shucai Xiao's avatar
Shucai Xiao committed
218
219
220
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
Shucai Xiao's avatar
Shucai Xiao committed
221
    auto sih      = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
222
    auto sih_lens = sih->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
223
224

    // bias
225
    instruction_ref bb{};
Shucai Xiao's avatar
Shucai Xiao committed
226
    if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
227
    {
228
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
Shucai Xiao's avatar
Shucai Xiao committed
229
230
231
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
232
233
        auto wrb   = prog.insert_instruction(ins, op::add{}, wb, rb);
        bb         = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
Shucai Xiao's avatar
Shucai Xiao committed
234
235
    }

Shucai Xiao's avatar
Shucai Xiao committed
236
237
    instruction_ref hidden_out = prog.end();
    instruction_ref last_out{};
Shucai Xiao's avatar
Shucai Xiao committed
238
239
    last_out            = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
    std::size_t seq_len = input->get_shape().lens()[0];
Shucai Xiao's avatar
Shucai Xiao committed
240
241
    for(std::size_t i = 0; i < seq_len; i++)
    {
Shucai Xiao's avatar
Shucai Xiao committed
242
        long seq_index = is_forward ? i : (seq_len - 1 - i);
Shucai Xiao's avatar
Shucai Xiao committed
243
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
Shucai Xiao's avatar
Shucai Xiao committed
244
245
246
247
        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
        xt           = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
        auto xt_wi   = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri   = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
Shucai Xiao's avatar
Shucai Xiao committed
248
        if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
249
        {
250
            xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
Shucai Xiao's avatar
Shucai Xiao committed
251
        }
252
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
Shucai Xiao's avatar
Shucai Xiao committed
253
254

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
255
256
        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;
Shucai Xiao's avatar
Shucai Xiao committed
257

Shucai Xiao's avatar
Shucai Xiao committed
258
259
260
        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
Shucai Xiao's avatar
Shucai Xiao committed
261

Shucai Xiao's avatar
Shucai Xiao committed
262
263
264
        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
Shucai Xiao's avatar
Shucai Xiao committed
265
        if(i < seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
266
        {
Shucai Xiao's avatar
Shucai Xiao committed
267
268
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
269
270
271
272
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
273
274
275
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
276
277
278
279
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
280
            }
Shucai Xiao's avatar
Shucai Xiao committed
281
282
283
        }
    }

284
    return {hidden_out, last_out};
Shucai Xiao's avatar
Shucai Xiao committed
285
286
}

Shucai Xiao's avatar
Shucai Xiao committed
287
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    // Resolves the activation functions apply_vanilla_rnn uses: entry 0 for the
    // forward (or only) direction, entry 1 for the reverse direction of a
    // bidirectional rnn.
    //  - none specified: tanh (once, or twice for bidirectional);
    //  - one specified for a bidirectional rnn: shared by both directions;
    //  - otherwise the op's list is returned unchanged.
    auto rnn_op              = any_cast<op::rnn>(ins->get_operator());
    const auto& funcs        = rnn_op.actv_funcs;
    const bool bidirectional = (rnn_op.direction == op::rnn_direction::bidirectional);

    if(funcs.empty())
    {
        if(bidirectional)
            return {op::tanh{}, op::tanh{}};
        return {op::tanh{}};
    }

    if(bidirectional && funcs.size() == 1)
        return {funcs.front(), funcs.front()};

    return funcs;
}

324
325
326
327
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{
    // Lowers one "gru" instruction into primitive ops inserted before `ins`.
    // `ins` itself is turned into a concat of the hidden states of every time
    // step. Each "rnn_last_output" consumer of `ins` is redirected to the
    // squeezed hidden state of the last step.
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);

    // The op carries 3 to 6 inputs. parse_gru pads ONNX models up to 6 with
    // "undefined" placeholders, while hand-written programs may pass fewer.
    // Read here: [0] sequence, [1] W, [2] R, [3] bias (optional),
    // [5] initial hidden state (optional; honored only when exactly 6 inputs
    // are present). Input [4] is not read.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    // zero-filled fallback used when no initial hidden state is supplied
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> zeros(ih_shape.elements(), 0.0f);

    auto gru_op           = any_cast<op::gru>(ins->get_operator());
    op::rnn_direction dir = gru_op.direction;

    const bool bias_given = args.size() >= 4 && args[3]->name() != "undefined";
    const bool ih_given   = args.size() == 6 && args[5]->name() != "undefined";

    instruction_ref last_output{};
    if(dir == op::rnn_direction::bidirectional)
    {
        // Every per-direction tensor is split along axis 0:
        // slice [0, 1) feeds the forward pass, slice [1, 2) the reverse pass.
        auto w_fwd = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_rev = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
        auto r_fwd = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_rev = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // prog.end() tells gru_cell that there is no bias
        instruction_ref b_fwd = prog.end();
        instruction_ref b_rev = prog.end();
        if(bias_given)
        {
            b_fwd = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            b_rev = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        instruction_ref h0_fwd{};
        instruction_ref h0_rev{};
        if(ih_given)
        {
            h0_fwd = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            h0_rev = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            h0_fwd = prog.add_literal(migraphx::literal{ih_shape, zeros});
            h0_rev = prog.add_literal(migraphx::literal{ih_shape, zeros});
        }

        // activation functions 0/1 drive the forward pass, 2/3 the reverse pass
        auto fwd = gru_cell(true,
                            prog,
                            ins,
                            {args[0], w_fwd, r_fwd, b_fwd, h0_fwd},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));
        auto rev = gru_cell(false,
                            prog,
                            ins,
                            {args[0], w_rev, r_rev, b_rev, h0_rev},
                            gru_op.linear_before_reset,
                            actv_funcs.at(2),
                            actv_funcs.at(3));

        // final hidden state of both directions, with the leading axis dropped
        auto last_both = prog.insert_instruction(ins, op::concat{1}, fwd[1], rev[1]);
        last_output    = prog.insert_instruction(ins, op::squeeze{{0}}, last_both);

        // Make the final instruction produced from the gru a concat.
        if(fwd[0] == prog.end())
        {
            // sequence length 1: no earlier hidden states were collected
            prog.replace_instruction(ins, op::concat{1}, fwd[1], rev[1]);
        }
        else
        {
            // The reverse pass finishes at time step 0, so its last state
            // goes in front of the states it collected earlier.
            auto all_fwd = prog.insert_instruction(ins, op::concat{0}, fwd[0], fwd[1]);
            auto all_rev = prog.insert_instruction(ins, op::concat{0}, rev[1], rev[0]);
            prog.replace_instruction(ins, op::concat{1}, {all_fwd, all_rev});
        }
    }
    else
    {
        const bool is_forward = (dir == op::rnn_direction::forward);

        // prog.end() tells gru_cell that there is no bias
        instruction_ref bias = bias_given ? args[3] : prog.end();
        instruction_ref ih =
            ih_given ? args[5] : prog.add_literal(migraphx::literal{ih_shape, zeros});

        auto ret = gru_cell(is_forward,
                            prog,
                            ins,
                            {args[0], args[1], args[2], bias, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        // Make the final instruction produced from the gru a concat; the
        // reverse pass's last state belongs to time step 0 and goes first.
        if(ret[0] == prog.end())
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        else if(is_forward)
            prog.replace_instruction(ins, op::concat{0}, ret[0], ret[1]);
        else
            prog.replace_instruction(ins, op::concat{0}, ret[1], ret[0]);
    }

    // There may be several "rnn_last_output" consumers; redirect each one.
    for(auto out : ins->outputs())
    {
        if(out->name() == "rnn_last_output")
            prog.replace_instruction(out, last_output);
    }
}

std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
478
479
480
481
482
483
                                                   program& prog,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
484
485
486
487
488
489
490
491
{
    assert(inputs.size() == 5);
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
    auto bias = inputs.at(3);
    auto ih   = inputs.at(4);

Shucai Xiao's avatar
Shucai Xiao committed
492
493
    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
Shucai Xiao's avatar
Shucai Xiao committed
494
    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
495
496
497
    migraphx::shape r_shape   = r->get_shape();
    long seq_len              = static_cast<long>(seq_shape.lens()[0]);
    long hs                   = static_cast<long>(r_shape.lens()[2]);
498

Shucai Xiao's avatar
Shucai Xiao committed
499
    migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
500
    std::vector<float> data(s.elements(), 1.0f);
501
502
    auto l1 = prog.add_literal(migraphx::literal{s, data});

503
    // w matrix squeeze to 2-dim and do a transpose
504
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
505
    auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
506
    auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
507

508
    // r slide to two part, zr and h
Shucai Xiao's avatar
Shucai Xiao committed
509
510
511
    auto sr   = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rzr  = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
    auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
512

Shucai Xiao's avatar
Shucai Xiao committed
513
    auto rh  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
514
    auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
515
516

    // initial states
Shucai Xiao's avatar
Shucai Xiao committed
517
518
    auto sih  = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
    size_t bs = ih->get_shape().lens()[1];
519
520

    // bias
521
522
523
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
524
525
526
    if(bias != prog.end())
    {
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
527
528
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
        bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
529
530

        auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
Shucai Xiao's avatar
Shucai Xiao committed
531
532
533
534
        auto rb_h  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brb_zr     = prog.insert_instruction(
            ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
        brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
535
536
537
538
539
540
    }

    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
Shucai Xiao's avatar
Shucai Xiao committed
541
542
        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
        xt           = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
543

544
545
        auto xt_w    = prog.insert_instruction(ins, op::dot{}, xt, tw);
        auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
Shucai Xiao's avatar
Shucai Xiao committed
546
        if(bias != prog.end())
547
        {
548
549
            xt_w    = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
            ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
550
551
552
553
554
555
556
557
558
559
        }

        auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
        auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
        auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);

        auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
        auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);

        auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
Shucai Xiao's avatar
Shucai Xiao committed
560
        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);
561
562

        auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
Shucai Xiao's avatar
Shucai Xiao committed
563
        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);
564
565
566
567
568
569

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
570
            hr_h        = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
Shucai Xiao's avatar
Shucai Xiao committed
571
            if(bias != prog.end())
572
            {
573
                hr_h = prog.insert_instruction(ins, op::add{}, hr_h, brb_h);
574
            }
575
576
577
        }
        else
        {
578
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
579
            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
Shucai Xiao's avatar
Shucai Xiao committed
580
            if(bias != prog.end())
581
            {
582
                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brb_h);
583
            }
Shucai Xiao's avatar
Shucai Xiao committed
584
            hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
585
586
        }

587
        auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
Shucai Xiao's avatar
Shucai Xiao committed
588
        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);
589

590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
        last_output          = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}

std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    // Expands the op's activation list to what apply_gru indexes: two entries
    // (gate, hidden) for one direction, four entries (forward gate/hidden,
    // reverse gate/hidden) for a bidirectional gru. Defaults are sigmoid for
    // the gate and tanh for the hidden activation. Lists that are already long
    // enough are returned unchanged.
    auto gru_op       = any_cast<op::gru>(ins->get_operator());
    const auto& funcs = gru_op.actv_funcs;

    if(gru_op.direction != op::rnn_direction::bidirectional)
    {
        switch(funcs.size())
        {
        case 0: return {op::sigmoid{}, op::tanh{}};
        case 1: return {funcs.at(0), funcs.at(0)};
        default: return funcs;
        }
    }

    switch(funcs.size())
    {
    case 0: return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
    case 1: return {funcs.at(0), funcs.at(0), funcs.at(0), funcs.at(0)};
    // the forward pair is reused for the reverse direction
    case 2: return {funcs.at(0), funcs.at(1), funcs.at(0), funcs.at(1)};
    // NOTE(review): the 4th entry reuses funcs[0] here; confirm this matches
    // the padding rule parse_gru applies to three activation functions.
    case 3: return {funcs.at(0), funcs.at(1), funcs.at(2), funcs.at(0)};
    default: return funcs;
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
659
660
661
662
663
// for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();
Shucai Xiao's avatar
Shucai Xiao committed
664

Shucai Xiao's avatar
Shucai Xiao committed
665
    shape seq_shape         = args[0]->get_shape();
666
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
Shucai Xiao's avatar
Shucai Xiao committed
667
668
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
Shucai Xiao's avatar
Shucai Xiao committed
669
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
670
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
Shucai Xiao's avatar
Shucai Xiao committed
671
672
673

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
674
675
    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
Shucai Xiao's avatar
Shucai Xiao committed
676
    op::rnn_direction dirct = lstm_op.direction;
Shucai Xiao's avatar
Shucai Xiao committed
677
678
679

    instruction_ref last_output{};
    instruction_ref last_cell_output{};
Shucai Xiao's avatar
Shucai Xiao committed
680
    if(dirct == op::rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
681
682
683
684
685
686
687
688
689
690
691
692
693
    {
        // input weight matrix
        // input weight matrix
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // process bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
694
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // process intial hidden state, it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
717
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
718
719
720
721
722
723
724
725
726
727
728
        {
            ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]);
            ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]);
        }
        else
        {
            ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
729
730
        instruction_ref pph_forward = prog.end();
        instruction_ref pph_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
731
        if(args.size() == 8 && args[7]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
732
733
734
735
        {
            pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
            pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
        }
Shucai Xiao's avatar
Shucai Xiao committed
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756

        auto ret_forward = lstm_cell(
            true,
            prog,
            ins,
            {args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward},
            actv_funcs.at(0),
            actv_funcs.at(1),
            actv_funcs.at(2));

        auto ret_reverse = lstm_cell(
            false,
            prog,
            ins,
            {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse},
            actv_funcs.at(3),
            actv_funcs.at(4),
            actv_funcs.at(5));

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
757
758
759
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // last cell output
760
761
        last_cell_output =
            prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_reverse[2]);
Shucai Xiao's avatar
Shucai Xiao committed
762
763

        // the following logic is to ensure the last instruction is a concat
Shucai Xiao's avatar
Shucai Xiao committed
764
        if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
765
766
767
768
769
        {
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
770
771
772
773
774
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
Shucai Xiao's avatar
Shucai Xiao committed
775
776
777
778
        }
    }
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
779
        bool is_forward = (dirct == op::rnn_direction::forward);
Shucai Xiao's avatar
Shucai Xiao committed
780
781
782
783
784
785
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
786
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
787
788
789
790
791
792
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
Shucai Xiao's avatar
Shucai Xiao committed
793
        if(args.size() >= 6 && args[5]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
794
795
796
797
798
799
800
801
802
803
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
Shucai Xiao's avatar
Shucai Xiao committed
804
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
805
806
807
808
809
810
811
812
813
        {
            ic = args[6];
        }
        else
        {
            ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
814
        instruction_ref pph = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
815
816
817
818
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }
Shucai Xiao's avatar
Shucai Xiao committed
819
820

        auto ret = lstm_cell(is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
821
822
823
824
825
826
827
                             prog,
                             ins,
                             {args[0], w, r, bias, ih, ic, pph},
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

Shucai Xiao's avatar
Shucai Xiao committed
828
        last_output      = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
829
        last_cell_output = ret[2];
Shucai Xiao's avatar
Shucai Xiao committed
830
        if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
831
832
833
834
835
836
837
838
839
840
841
842
843
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // replace the corresponding lstm_last_output instruction
    // with the last_output, and the lstm_last_cell_output with
Shucai Xiao's avatar
Shucai Xiao committed
844
    // the last_cell_output. The while loop is to handle the case
Shucai Xiao's avatar
Shucai Xiao committed
845
846
847
848
849
850
    // of multiple lstm_last_output and lstm_last_cell_output
    // operators
    auto last_output_it = ins->outputs().begin();
    while(last_output_it != ins->outputs().end())
    {
        last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
851
            return i->name() == "rnn_last_output";
Shucai Xiao's avatar
Shucai Xiao committed
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
        });

        if(last_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_output_it, last_output);
            last_output_it++;
        }
    }

    auto last_cell_output_it = ins->outputs().begin();
    while(last_cell_output_it != ins->outputs().end())
    {
        last_cell_output_it = std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
            return i->name() == "lstm_last_cell_output";
        });

        if(last_cell_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_cell_output_it, last_cell_output);
            last_cell_output_it++;
        }
Shucai Xiao's avatar
Shucai Xiao committed
873
    }
Shucai Xiao's avatar
Shucai Xiao committed
874
875
876
}

std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
877
878
879
880
881
882
                                                    program& prog,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
Shucai Xiao's avatar
Shucai Xiao committed
883
{
Shucai Xiao's avatar
Shucai Xiao committed
884
885
    // must have 7 args in the input vector
    assert(inputs.size() == 7);
Shucai Xiao's avatar
Shucai Xiao committed
886
887
888
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
Shucai Xiao's avatar
Shucai Xiao committed
889
    auto bias = inputs.at(3);
Shucai Xiao's avatar
Shucai Xiao committed
890
891
892
    auto ih   = inputs.at(4);
    auto ic   = inputs.at(5);
    auto pph  = inputs.at(6);
Shucai Xiao's avatar
Shucai Xiao committed
893

Shucai Xiao's avatar
Shucai Xiao committed
894
895
896
897
898
    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
    instruction_ref last_cell_output{};

    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
899
    migraphx::shape r_shape   = r->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
900
    long seq_len              = static_cast<long>(seq_shape.lens()[0]);
Shucai Xiao's avatar
Shucai Xiao committed
901
    long hs                   = static_cast<long>(r_shape.lens()[2]);
Shucai Xiao's avatar
Shucai Xiao committed
902
    auto bs                   = ih->get_shape().lens()[1];
Shucai Xiao's avatar
Shucai Xiao committed
903
904

    std::vector<int64_t> perm{1, 0};
905
    // w matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
906
    auto sw  = prog.insert_instruction(ins, op::squeeze{{0}}, w);
907
    auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
908

909
    // r matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
910
    auto sr  = prog.insert_instruction(ins, op::squeeze{{0}}, r);
911
    auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);
Shucai Xiao's avatar
Shucai Xiao committed
912

Shucai Xiao's avatar
Shucai Xiao committed
913
914
915
916
    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // initial cell state
Shucai Xiao's avatar
Shucai Xiao committed
917
    auto sic     = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
918
    auto ic_lens = sic->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
919

920
    // bias
921
    instruction_ref wrb{};
Shucai Xiao's avatar
Shucai Xiao committed
922
    if(bias != prog.end())
923
    {
924

Shucai Xiao's avatar
Shucai Xiao committed
925
926
927
        auto sbias  = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto ub_wb  = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
        auto ub_rb  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
928
        auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
929

930
931
        wrb = prog.insert_instruction(
            ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
932
933
    }

Shucai Xiao's avatar
Shucai Xiao committed
934
    // peep hole
Shucai Xiao's avatar
Shucai Xiao committed
935
936
937
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
Shucai Xiao's avatar
Shucai Xiao committed
938
    if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
939
    {
Shucai Xiao's avatar
Shucai Xiao committed
940
941
        auto spph  = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
        auto pphi  = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
942
        pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi);
Shucai Xiao's avatar
Shucai Xiao committed
943

Shucai Xiao's avatar
Shucai Xiao committed
944
        auto ppho  = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
945
        ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho);
Shucai Xiao's avatar
Shucai Xiao committed
946

Shucai Xiao's avatar
Shucai Xiao committed
947
        auto pphf  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
948
        pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf);
Shucai Xiao's avatar
Shucai Xiao committed
949
    }
Shucai Xiao's avatar
Shucai Xiao committed
950

Shucai Xiao's avatar
Shucai Xiao committed
951
    for(long i = 0; i < seq_len; ++i)
Shucai Xiao's avatar
Shucai Xiao committed
952
953
954
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
Shucai Xiao's avatar
Shucai Xiao committed
955
956
        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
        xt           = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
Shucai Xiao's avatar
Shucai Xiao committed
957

958
959
        auto xt_tsw  = prog.insert_instruction(ins, op::dot{}, xt, tsw);
        auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
Shucai Xiao's avatar
Shucai Xiao committed
960
        auto xt_sih  = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
Shucai Xiao's avatar
Shucai Xiao committed
961
        if(bias != prog.end())
962
        {
Shucai Xiao's avatar
Shucai Xiao committed
963
            xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
964
        }
Shucai Xiao's avatar
Shucai Xiao committed
965

966
967
        auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
        auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
Shucai Xiao's avatar
Shucai Xiao committed
968
969
970
971
        auto ft_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
        auto ct_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
972

Shucai Xiao's avatar
Shucai Xiao committed
973
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
974
        {
Shucai Xiao's avatar
Shucai Xiao committed
975
976
977
            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);

Shucai Xiao's avatar
Shucai Xiao committed
978
979
            auto pphf_ct   = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
            ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
Shucai Xiao's avatar
Shucai Xiao committed
980
        }
Shucai Xiao's avatar
Shucai Xiao committed
981
        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
982
        auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
983
        auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
984
985

        // equation Ct = ft (.) Ct-1 + it (.) ct
Shucai Xiao's avatar
Shucai Xiao committed
986
987
988
        auto ft_cell     = prog.insert_instruction(ins, op::mul{}, ft, sic);
        auto it_ct       = prog.insert_instruction(ins, op::mul{}, it, ct);
        auto cellt       = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
989
990
        last_cell_output = cellt;

Shucai Xiao's avatar
Shucai Xiao committed
991
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
992
        {
Shucai Xiao's avatar
Shucai Xiao committed
993
994
            auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
            ot_before_actv  = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
Shucai Xiao's avatar
Shucai Xiao committed
995
        }
996
997
998
999
        auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);

        // Ht = ot (.) h(Ct)
        auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
Shucai Xiao's avatar
Shucai Xiao committed
1000
        auto ht      = prog.insert_instruction(ins, op::mul{}, ot, h_cellt);
1001
1002
1003
1004
1005
1006

        sic = cellt;
        sih = ht;

        last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);

Shucai Xiao's avatar
Shucai Xiao committed
1007
        if(i < seq_len - 1)
1008
        {
Shucai Xiao's avatar
Shucai Xiao committed
1009
            if(i == 0)
1010
            {
Shucai Xiao's avatar
Shucai Xiao committed
1011
                hidden_states = last_output;
1012
1013
1014
1015
1016
            }
            else
            {
                auto concat_arg0 = is_forward ? hidden_states : last_output;
                auto concat_arg1 = is_forward ? last_output : hidden_states;
Shucai Xiao's avatar
Shucai Xiao committed
1017
1018
                hidden_states =
                    prog.insert_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
1019
1020
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1021
    }
1022
1023
1024
1025

    last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0}}, last_cell_output);

    return {hidden_states, last_output, last_cell_output};
Shucai Xiao's avatar
Shucai Xiao committed
1026
1027
1028
1029
1030
1031
1032
1033
1034
}

std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    auto lstm_op = any_cast<op::lstm>(ins->get_operator());
    // Before rewriting the lstm operator, make sure there is a full set of
    // activation functions even if the user specified fewer (or none):
    // 6 for a bidirectional lstm (apply_lstm uses entries 0-2 for the
    // forward direction and 3-5 for the reverse one), 3 otherwise. Missing
    // entries are filled in following the algorithm used in parse_lstm.
    const auto& actv_funcs     = lstm_op.actv_funcs;
    std::size_t num_actv_funcs = actv_funcs.size();
    if(lstm_op.direction == op::rnn_direction::bidirectional)
    {
        switch(num_actv_funcs)
        {
        case 0:
            return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};

        case 1:
            return {actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0)};

        case 2:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1)};

        case 3:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2)};

        case 4:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(3),
                    actv_funcs.at(3)};

        case 5:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(4),
                    actv_funcs.at(4)};

        default: return actv_funcs;
        }
    }
    else
    {
        switch(num_actv_funcs)
        {
        case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}};

        case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};

        case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};

        default: return actv_funcs;
        }
    }
}

1102
namespace op {
Shucai Xiao's avatar
Shucai Xiao committed
1103
std::ostream& operator<<(std::ostream& os, rnn_direction v)
1104
{
Shucai Xiao's avatar
Shucai Xiao committed
1105
1106
    std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
1107
1108
    return os;
}
Shucai Xiao's avatar
Shucai Xiao committed
1109
} // namespace op
1110

Shucai Xiao's avatar
Shucai Xiao committed
1111
1112
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx