rewrite_rnn.cpp 42.1 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
18
19
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
20
#include <migraphx/op/common.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
21
22
23
24

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Shucai Xiao's avatar
Shucai Xiao committed
25
// Visit every instruction of the program and hand each recurrent operator
// ("rnn", "gru" or "lstm") to the helper that rewrites it. Instructions with
// any other name are left untouched.
void rewrite_rnn::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        // Read the name once; it is compared against several candidates.
        const auto op_name = ins->name();
        if(op_name == "rnn")
        {
            apply_vanilla_rnn(prog, ins);
        }
        else if(op_name == "gru")
        {
            apply_gru(prog, ins);
        }
        else if(op_name == "lstm")
        {
            apply_lstm(prog, ins);
        }
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
44
// Rewrite one "rnn" instruction. The per-time-step computation is emitted by
// vanilla_rnn_cell() in front of `ins`; `ins` itself is then replaced by a
// concat instruction, and every "rnn_last_output" instruction that consumes
// `ins` is replaced by the squeezed final state.
//
// prog: program being rewritten.
// ins:  the "rnn" instruction. Its inputs are read as
//         args[0] input sequence,
//         args[1] input weight matrix,
//         args[2] hidden state weight matrix,
//         args[3] bias (optional),
//         args[5] initial hidden state (optional, only when 6 inputs exist).
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "rnn");
    // There can be 3 to 6 inputs. parse_rnn appends "undefined" operators to
    // make 6 arguments when parsing an onnx file, whereas a user writing a
    // program by hand may pass any number of arguments in that range.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    // all-zero data used when no initial hidden state is supplied
    std::vector<float> data(ih_shape.elements(), 0);

    auto actv_funcs             = vanilla_rnn_actv_funcs(ins);
    auto rnn_op                 = any_cast<op::rnn>(ins->get_operator());
    op::rnn_direction direction = rnn_op.direction;
    instruction_ref last_output{};
    if(direction == op::rnn_direction::bidirectional)
    {
        // input weight matrix, one slice per direction
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix, one slice per direction
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // bias; prog.end() marks "no bias" for vanilla_rnn_cell()
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // initial hidden state: taken from the 6th argument when it is
        // present and defined, otherwise a zero literal is used
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward = vanilla_rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            actv_funcs.at(0));
        auto ret_reverse = vanilla_rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            actv_funcs.at(1));

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic ensures that the last instruction rewritten
        // from the rnn operator is a concat instruction.
        if(ret_forward[0] == prog.end())
        {
            // sequence length is 1: only the last states exist
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
            // forward: earlier states followed by the last one;
            // reverse: the last computed state goes in front
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (direction == op::rnn_direction::forward);
        // input weight matrix
        auto w = args[1];

        // hidden state weight matrix
        auto r = args[2];

        // bias; prog.end() marks "no bias" for vanilla_rnn_cell()
        instruction_ref bias = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state: the 6th argument when present and defined,
        // otherwise a zero literal
        instruction_ref ih{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret =
            vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        // The following logic ensures that the last instruction is a concat
        // instruction.
        if(ret[0] == prog.end())
        {
            // sequence length is 1
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // Search the outputs of ins for rnn_last_output operators; the while loop
    // handles the case of multiple rnn_last_output operators.
    auto last_output_it = ins->outputs().begin();
    while(last_output_it != ins->outputs().end())
    {
        last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
            return i->name() == "rnn_last_output";
        });

        if(last_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_output_it, last_output);
            ++last_output_it;
        }
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
200
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
201
202
203
204
205
206
207
208
                                                           program& prog,
                                                           instruction_ref ins,
                                                           instruction_ref input,
                                                           instruction_ref w,
                                                           instruction_ref r,
                                                           instruction_ref bias,
                                                           instruction_ref ih,
                                                           operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
209
{
Shucai Xiao's avatar
Shucai Xiao committed
210
211
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
212
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
Shucai Xiao's avatar
Shucai Xiao committed
213
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
214
215

    // squeeze and transpose r
Shucai Xiao's avatar
Shucai Xiao committed
216
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
Shucai Xiao's avatar
Shucai Xiao committed
217
218
219
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
Shucai Xiao's avatar
Shucai Xiao committed
220
    auto sih      = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
221
    auto sih_lens = sih->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
222
223

    // bias
224
    instruction_ref bb{};
Shucai Xiao's avatar
Shucai Xiao committed
225
    if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
226
    {
227
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
Shucai Xiao's avatar
Shucai Xiao committed
228
229
230
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
231
232
        auto wrb   = prog.insert_instruction(ins, op::add{}, wb, rb);
        bb         = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
Shucai Xiao's avatar
Shucai Xiao committed
233
234
    }

Shucai Xiao's avatar
Shucai Xiao committed
235
236
    instruction_ref hidden_out = prog.end();
    instruction_ref last_out{};
Shucai Xiao's avatar
Shucai Xiao committed
237
238
    last_out            = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
    std::size_t seq_len = input->get_shape().lens()[0];
Shucai Xiao's avatar
Shucai Xiao committed
239
240
    for(std::size_t i = 0; i < seq_len; i++)
    {
Shucai Xiao's avatar
Shucai Xiao committed
241
        long seq_index = is_forward ? i : (seq_len - 1 - i);
Shucai Xiao's avatar
Shucai Xiao committed
242
243
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
244
245
        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
Shucai Xiao's avatar
Shucai Xiao committed
246
        if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
247
        {
248
            xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
Shucai Xiao's avatar
Shucai Xiao committed
249
        }
250
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
Shucai Xiao's avatar
Shucai Xiao committed
251
252

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
253
254
        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;
Shucai Xiao's avatar
Shucai Xiao committed
255

Shucai Xiao's avatar
Shucai Xiao committed
256
257
258
        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
Shucai Xiao's avatar
Shucai Xiao committed
259

Shucai Xiao's avatar
Shucai Xiao committed
260
261
262
        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
Shucai Xiao's avatar
Shucai Xiao committed
263
        if(i < seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
264
        {
Shucai Xiao's avatar
Shucai Xiao committed
265
266
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
267
268
269
270
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
271
272
273
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
274
275
276
277
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
278
            }
Shucai Xiao's avatar
Shucai Xiao committed
279
280
281
        }
    }

282
    return {hidden_out, last_out};
Shucai Xiao's avatar
Shucai Xiao committed
283
284
}

Shucai Xiao's avatar
Shucai Xiao committed
285
// Return the activation functions to use when rewriting the "rnn" instruction
// `ins`: two entries for a bidirectional operator, otherwise whatever the
// operator carries, with tanh filled in when none was specified.
//
//   bidirectional, none given -> {tanh, tanh}
//   bidirectional, one given  -> that function for both directions
//   otherwise, none given     -> {tanh}
//   any other case            -> the operator's own list, unchanged
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    auto rnn_op              = any_cast<op::rnn>(ins->get_operator());
    const auto& funcs        = rnn_op.actv_funcs;
    const bool bidirectional = (rnn_op.direction == op::rnn_direction::bidirectional);

    if(funcs.empty())
    {
        // default is tanh
        if(bidirectional)
        {
            return {op::tanh{}, op::tanh{}};
        }
        return {op::tanh{}};
    }

    if(bidirectional && funcs.size() == 1)
    {
        // a single function is shared by both directions
        return {funcs.at(0), funcs.at(0)};
    }

    return funcs;
}

322
323
324
325
// Rewrite one "gru" instruction. The per-time-step computation is emitted by
// gru_cell() in front of `ins`; `ins` itself is then replaced by a concat
// instruction, and every "rnn_last_output" instruction that consumes `ins` is
// replaced by the squeezed final state.
//
// prog: program being rewritten.
// ins:  the "gru" instruction. Its inputs are read as
//         args[0] input sequence,
//         args[1] w weight matrix,
//         args[2] r weight matrix,
//         args[3] bias (optional),
//         args[5] initial hidden state (optional, only when 6 inputs exist).
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);
    // There can be 3 to 6 inputs. parse_gru appends "undefined" operators to
    // make 6 arguments when parsing an onnx file, whereas a user writing a
    // program by hand may pass any number of arguments in that range.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    // all-zero data used when no initial hidden state is supplied
    std::vector<float> data(ih_shape.elements(), 0.0);

    auto gru_op                 = any_cast<op::gru>(ins->get_operator());
    op::rnn_direction direction = gru_op.direction;
    instruction_ref last_output{};
    if(direction == op::rnn_direction::bidirectional)
    {
        // w weight matrix, one slice per direction
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // r weight matrix, one slice per direction
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // bias; prog.end() marks "no bias" for gru_cell()
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // initial hidden state: the 6th argument when present and defined,
        // otherwise a zero literal
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        // activation functions 0 and 1 serve the forward direction,
        // 2 and 3 the reverse direction
        auto ret_forward = gru_cell(true,
                                    prog,
                                    ins,
                                    {args[0], w_forward, r_forward, bias_forward, ih_forward},
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(0),
                                    actv_funcs.at(1));

        auto ret_reverse = gru_cell(false,
                                    prog,
                                    ins,
                                    {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(2),
                                    actv_funcs.at(3));

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic ensures that the last instruction rewritten
        // from the gru operator is a concat.
        if(ret_forward[0] == prog.end())
        {
            // gru_cell() produced no accumulated states: only the last ones
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
            // forward: earlier states followed by the last one;
            // reverse: the last computed state goes in front
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (direction == op::rnn_direction::forward);
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias; prog.end() marks "no bias" for gru_cell()
        instruction_ref bias = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state: the 6th argument when present and defined,
        // otherwise a zero literal
        instruction_ref ih{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret = gru_cell(is_forward,
                            prog,
                            ins,
                            {args[0], w, r, bias, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // Replace the corresponding rnn_last_output instruction with last_output,
    // if rnn_last_output exists; the while loop handles the case of multiple
    // rnn_last_output operators.
    auto last_output_it = ins->outputs().begin();
    while(last_output_it != ins->outputs().end())
    {
        last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
            return i->name() == "rnn_last_output";
        });

        if(last_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_output_it, last_output);
            ++last_output_it;
        }
    }
}

std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
476
477
478
479
480
481
                                                   program& prog,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
482
483
484
485
486
487
488
489
{
    assert(inputs.size() == 5);
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
    auto bias = inputs.at(3);
    auto ih   = inputs.at(4);

Shucai Xiao's avatar
Shucai Xiao committed
490
491
    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
Shucai Xiao's avatar
Shucai Xiao committed
492
    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
493
494
495
    migraphx::shape r_shape   = r->get_shape();
    long seq_len              = static_cast<long>(seq_shape.lens()[0]);
    long hs                   = static_cast<long>(r_shape.lens()[2]);
496

Shucai Xiao's avatar
Shucai Xiao committed
497
    migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
498
    std::vector<float> data(s.elements(), 1.0f);
499
500
    auto l1 = prog.add_literal(migraphx::literal{s, data});

501
    // w matrix squeeze to 2-dim and do a transpose
502
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
503
    auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
504
    auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
505

506
    // r slide to two part, zr and h
Shucai Xiao's avatar
Shucai Xiao committed
507
508
509
    auto sr   = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rzr  = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
    auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
510

Shucai Xiao's avatar
Shucai Xiao committed
511
    auto rh  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
512
    auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
513
514

    // initial states
Shucai Xiao's avatar
Shucai Xiao committed
515
516
    auto sih  = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
    size_t bs = ih->get_shape().lens()[1];
517
518

    // bias
519
520
521
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
522
523
524
    if(bias != prog.end())
    {
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
525
526
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
        bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
527
528

        auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
Shucai Xiao's avatar
Shucai Xiao committed
529
530
531
532
        auto rb_h  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brb_zr     = prog.insert_instruction(
            ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
        brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
533
534
535
536
537
538
539
540
    }

    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

541
542
        auto xt_w    = prog.insert_instruction(ins, op::dot{}, xt, tw);
        auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
Shucai Xiao's avatar
Shucai Xiao committed
543
        if(bias != prog.end())
544
        {
545
546
            xt_w    = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
            ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
547
548
549
550
551
552
553
554
555
556
        }

        auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
        auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
        auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);

        auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
        auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);

        auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
Shucai Xiao's avatar
Shucai Xiao committed
557
        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);
558
559

        auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
Shucai Xiao's avatar
Shucai Xiao committed
560
        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);
561
562
563
564
565
566

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
Shucai Xiao's avatar
Shucai Xiao committed
567
            if(bias != prog.end())
568
            {
Shucai Xiao's avatar
Shucai Xiao committed
569
                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
570
            }
571
572
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
573
                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
574
            }
575
576
577
        }
        else
        {
578
579
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            instruction_ref ht1_rh{};
Shucai Xiao's avatar
Shucai Xiao committed
580
            if(bias != prog.end())
581
            {
582
                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
583
            }
584
            else
585
            {
Shucai Xiao's avatar
Shucai Xiao committed
586
                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
587
            }
Shucai Xiao's avatar
Shucai Xiao committed
588
            hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
589
590
        }

591
        auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
Shucai Xiao's avatar
Shucai Xiao committed
592
        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);
593

594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
        last_output          = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}

// Return the activation functions to use when rewriting the "gru" instruction
// `ins`. When the operator carries fewer functions than needed, the list is
// completed (the original comment notes this follows the algorithm used in
// parse_gru):
//
//   bidirectional (4 wanted)
//     none given  -> {sigmoid, tanh, sigmoid, tanh}
//     one given   -> {f0, f0, f0, f0}
//     two given   -> {f0, f1, f0, f1}
//     three given -> {f0, f1, f2, f0}
//     otherwise   -> the operator's own list, unchanged
//   any other direction (2 wanted)
//     none given  -> {sigmoid, tanh}
//     one given   -> {f0, f0}
//     otherwise   -> the operator's own list, unchanged
std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    auto gru_op       = any_cast<op::gru>(ins->get_operator());
    const auto& funcs = gru_op.actv_funcs;

    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        switch(funcs.size())
        {
        case 0: return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
        case 1: return {funcs.at(0), funcs.at(0), funcs.at(0), funcs.at(0)};
        case 2: return {funcs.at(0), funcs.at(1), funcs.at(0), funcs.at(1)};
        case 3: return {funcs.at(0), funcs.at(1), funcs.at(2), funcs.at(0)};
        default: return funcs;
        }
    }

    switch(funcs.size())
    {
    case 0: return {op::sigmoid{}, op::tanh{}};
    case 1: return {funcs.at(0), funcs.at(0)};
    default: return funcs;
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
663
664
665
666
667
// for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();
Shucai Xiao's avatar
Shucai Xiao committed
668

Shucai Xiao's avatar
Shucai Xiao committed
669
    shape seq_shape         = args[0]->get_shape();
670
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
Shucai Xiao's avatar
Shucai Xiao committed
671
672
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
Shucai Xiao's avatar
Shucai Xiao committed
673
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
674
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
Shucai Xiao's avatar
Shucai Xiao committed
675
676
677

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
678
679
    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
Shucai Xiao's avatar
Shucai Xiao committed
680
    op::rnn_direction dirct = lstm_op.direction;
Shucai Xiao's avatar
Shucai Xiao committed
681
682
683

    instruction_ref last_output{};
    instruction_ref last_cell_output{};
Shucai Xiao's avatar
Shucai Xiao committed
684
    if(dirct == op::rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
685
686
687
688
689
690
691
692
693
694
695
696
697
    {
        // input weight matrix
        // input weight matrix
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // process bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
698
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // process intial hidden state, it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
721
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
722
723
724
725
726
727
728
729
730
731
732
        {
            ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]);
            ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]);
        }
        else
        {
            ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
733
734
        instruction_ref pph_forward = prog.end();
        instruction_ref pph_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
735
        if(args.size() == 8 && args[7]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
736
737
738
739
        {
            pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
            pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
        }
Shucai Xiao's avatar
Shucai Xiao committed
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760

        auto ret_forward = lstm_cell(
            true,
            prog,
            ins,
            {args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward},
            actv_funcs.at(0),
            actv_funcs.at(1),
            actv_funcs.at(2));

        auto ret_reverse = lstm_cell(
            false,
            prog,
            ins,
            {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse},
            actv_funcs.at(3),
            actv_funcs.at(4),
            actv_funcs.at(5));

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
761
762
763
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // last cell output
764
765
        last_cell_output =
            prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_reverse[2]);
Shucai Xiao's avatar
Shucai Xiao committed
766
767

        // the following logic is to ensure the last instruction is a concat
Shucai Xiao's avatar
Shucai Xiao committed
768
        if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
769
770
771
772
773
        {
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
774
775
776
777
778
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
Shucai Xiao's avatar
Shucai Xiao committed
779
780
781
782
        }
    }
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
783
        bool is_forward = (dirct == op::rnn_direction::forward);
Shucai Xiao's avatar
Shucai Xiao committed
784
785
786
787
788
789
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
790
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
791
792
793
794
795
796
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
Shucai Xiao's avatar
Shucai Xiao committed
797
        if(args.size() >= 6 && args[5]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
798
799
800
801
802
803
804
805
806
807
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
Shucai Xiao's avatar
Shucai Xiao committed
808
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
809
810
811
812
813
814
815
816
817
        {
            ic = args[6];
        }
        else
        {
            ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
818
        instruction_ref pph = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
819
820
821
822
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }
Shucai Xiao's avatar
Shucai Xiao committed
823
824

        auto ret = lstm_cell(is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
825
826
827
828
829
830
831
                             prog,
                             ins,
                             {args[0], w, r, bias, ih, ic, pph},
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

Shucai Xiao's avatar
Shucai Xiao committed
832
        last_output      = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
833
        last_cell_output = ret[2];
Shucai Xiao's avatar
Shucai Xiao committed
834
        if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
835
836
837
838
839
840
841
842
843
844
845
846
847
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // replace the corresponding lstm_last_output instruction
    // with the last_output, and the lstm_last_cell_output with
Shucai Xiao's avatar
Shucai Xiao committed
848
    // the last_cell_output. The while loop is to handle the case
Shucai Xiao's avatar
Shucai Xiao committed
849
850
851
852
853
854
    // of multiple lstm_last_output and lstm_last_cell_output
    // operators
    auto last_output_it = ins->outputs().begin();
    while(last_output_it != ins->outputs().end())
    {
        last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
855
            return i->name() == "rnn_last_output";
Shucai Xiao's avatar
Shucai Xiao committed
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
        });

        if(last_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_output_it, last_output);
            last_output_it++;
        }
    }

    auto last_cell_output_it = ins->outputs().begin();
    while(last_cell_output_it != ins->outputs().end())
    {
        last_cell_output_it = std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
            return i->name() == "lstm_last_cell_output";
        });

        if(last_cell_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_cell_output_it, last_cell_output);
            last_cell_output_it++;
        }
Shucai Xiao's avatar
Shucai Xiao committed
877
    }
Shucai Xiao's avatar
Shucai Xiao committed
878
879
880
}

std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
881
882
883
884
885
886
                                                    program& prog,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
Shucai Xiao's avatar
Shucai Xiao committed
887
{
Shucai Xiao's avatar
Shucai Xiao committed
888
889
    // must have 7 args in the input vector
    assert(inputs.size() == 7);
Shucai Xiao's avatar
Shucai Xiao committed
890
891
892
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
Shucai Xiao's avatar
Shucai Xiao committed
893
    auto bias = inputs.at(3);
Shucai Xiao's avatar
Shucai Xiao committed
894
895
896
    auto ih   = inputs.at(4);
    auto ic   = inputs.at(5);
    auto pph  = inputs.at(6);
Shucai Xiao's avatar
Shucai Xiao committed
897

Shucai Xiao's avatar
Shucai Xiao committed
898
899
900
901
902
    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
    instruction_ref last_cell_output{};

    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
903
    migraphx::shape r_shape   = r->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
904
    long seq_len              = static_cast<long>(seq_shape.lens()[0]);
Shucai Xiao's avatar
Shucai Xiao committed
905
    long hs                   = static_cast<long>(r_shape.lens()[2]);
Shucai Xiao's avatar
Shucai Xiao committed
906
    auto bs                   = ih->get_shape().lens()[1];
Shucai Xiao's avatar
Shucai Xiao committed
907
908

    std::vector<int64_t> perm{1, 0};
909
    // w matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
910
    auto sw  = prog.insert_instruction(ins, op::squeeze{{0}}, w);
911
    auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
912

913
    // r matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
914
    auto sr  = prog.insert_instruction(ins, op::squeeze{{0}}, r);
915
    auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);
Shucai Xiao's avatar
Shucai Xiao committed
916

Shucai Xiao's avatar
Shucai Xiao committed
917
918
919
920
    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // initial cell state
Shucai Xiao's avatar
Shucai Xiao committed
921
    auto sic     = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
922
    auto ic_lens = sic->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
923

924
    // bias
925
    instruction_ref wrb{};
Shucai Xiao's avatar
Shucai Xiao committed
926
    if(bias != prog.end())
927
    {
928

Shucai Xiao's avatar
Shucai Xiao committed
929
930
931
        auto sbias  = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto ub_wb  = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
        auto ub_rb  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
932
        auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
933

934
935
        wrb = prog.insert_instruction(
            ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
936
937
    }

Shucai Xiao's avatar
Shucai Xiao committed
938
    // peep hole
Shucai Xiao's avatar
Shucai Xiao committed
939
940
941
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
Shucai Xiao's avatar
Shucai Xiao committed
942
    if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
943
    {
Shucai Xiao's avatar
Shucai Xiao committed
944
945
        auto spph  = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
        auto pphi  = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
946
        pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi);
Shucai Xiao's avatar
Shucai Xiao committed
947

Shucai Xiao's avatar
Shucai Xiao committed
948
        auto ppho  = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
949
        ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho);
Shucai Xiao's avatar
Shucai Xiao committed
950

Shucai Xiao's avatar
Shucai Xiao committed
951
        auto pphf  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
952
        pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf);
Shucai Xiao's avatar
Shucai Xiao committed
953
    }
Shucai Xiao's avatar
Shucai Xiao committed
954

Shucai Xiao's avatar
Shucai Xiao committed
955
    for(long i = 0; i < seq_len; ++i)
Shucai Xiao's avatar
Shucai Xiao committed
956
957
958
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
Shucai Xiao's avatar
Shucai Xiao committed
959
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
Shucai Xiao's avatar
Shucai Xiao committed
960

961
962
        auto xt_tsw  = prog.insert_instruction(ins, op::dot{}, xt, tsw);
        auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
Shucai Xiao's avatar
Shucai Xiao committed
963
        auto xt_sih  = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
Shucai Xiao's avatar
Shucai Xiao committed
964
        if(bias != prog.end())
965
        {
Shucai Xiao's avatar
Shucai Xiao committed
966
            xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
967
        }
Shucai Xiao's avatar
Shucai Xiao committed
968

969
970
        auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
        auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
Shucai Xiao's avatar
Shucai Xiao committed
971
972
973
974
        auto ft_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
        auto ct_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
975

Shucai Xiao's avatar
Shucai Xiao committed
976
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
977
        {
Shucai Xiao's avatar
Shucai Xiao committed
978
979
980
            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);

Shucai Xiao's avatar
Shucai Xiao committed
981
982
            auto pphf_ct   = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
            ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
Shucai Xiao's avatar
Shucai Xiao committed
983
        }
Shucai Xiao's avatar
Shucai Xiao committed
984
        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
985
        auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
986
        auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
987
988

        // equation Ct = ft (.) Ct-1 + it (.) ct
Shucai Xiao's avatar
Shucai Xiao committed
989
990
991
        auto ft_cell     = prog.insert_instruction(ins, op::mul{}, ft, sic);
        auto it_ct       = prog.insert_instruction(ins, op::mul{}, it, ct);
        auto cellt       = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
992
993
        last_cell_output = cellt;

Shucai Xiao's avatar
Shucai Xiao committed
994
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
995
        {
Shucai Xiao's avatar
Shucai Xiao committed
996
997
            auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
            ot_before_actv  = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
Shucai Xiao's avatar
Shucai Xiao committed
998
        }
999
1000
1001
1002
        auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);

        // Ht = ot (.) h(Ct)
        auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
Shucai Xiao's avatar
Shucai Xiao committed
1003
        auto ht      = prog.insert_instruction(ins, op::mul{}, ot, h_cellt);
1004
1005
1006
1007
1008
1009

        sic = cellt;
        sih = ht;

        last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);

Shucai Xiao's avatar
Shucai Xiao committed
1010
        if(i < seq_len - 1)
1011
        {
Shucai Xiao's avatar
Shucai Xiao committed
1012
            if(i == 0)
1013
            {
Shucai Xiao's avatar
Shucai Xiao committed
1014
                hidden_states = last_output;
1015
1016
1017
1018
1019
            }
            else
            {
                auto concat_arg0 = is_forward ? hidden_states : last_output;
                auto concat_arg1 = is_forward ? last_output : hidden_states;
Shucai Xiao's avatar
Shucai Xiao committed
1020
1021
                hidden_states =
                    prog.insert_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
1022
1023
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1024
    }
1025
1026
1027
1028

    last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0}}, last_cell_output);

    return {hidden_states, last_output, last_cell_output};
Shucai Xiao's avatar
Shucai Xiao committed
1029
1030
1031
1032
1033
1034
1035
1036
1037
}

// Return the activation functions to use when rewriting the "lstm"
// instruction `ins`.
//
// Before rewriting the lstm operator we need to ensure the full set of
// activation functions is available, even when the user specified fewer
// (or none). A bidirectional lstm is padded to 6 functions, any other
// direction to 3, using the same algorithm as parse_lstm. A list that is
// already long enough is returned unchanged.
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    auto lstm_op               = any_cast<op::lstm>(ins->get_operator());
    const auto& actv_funcs     = lstm_op.actv_funcs;
    std::size_t num_actv_funcs = actv_funcs.size();

    // build a list from the user supplied functions at the given positions
    auto pick = [&](std::initializer_list<std::size_t> positions) {
        std::vector<operation> picked;
        picked.reserve(positions.size());
        for(auto pos : positions)
        {
            picked.push_back(actv_funcs.at(pos));
        }
        return picked;
    };

    if(lstm_op.direction == op::rnn_direction::bidirectional)
    {
        switch(num_actv_funcs)
        {
        case 0:
            return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};

        case 1: return pick({0, 0, 0, 0, 0, 0});

        case 2: return pick({0, 1, 1, 0, 1, 1});

        case 3: return pick({0, 1, 2, 0, 1, 2});

        case 4: return pick({0, 1, 2, 3, 3, 3});

        case 5: return pick({0, 1, 2, 3, 4, 4});

        default: return actv_funcs;
        }
    }

    // forward or reverse direction: only 3 functions are needed
    switch(num_actv_funcs)
    {
    case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}};

    case 1: return pick({0, 0, 0});

    case 2: return pick({0, 1, 1});

    default: return actv_funcs;
    }
}

1105
namespace op {
Shucai Xiao's avatar
Shucai Xiao committed
1106
std::ostream& operator<<(std::ostream& os, rnn_direction v)
1107
{
Shucai Xiao's avatar
Shucai Xiao committed
1108
1109
    std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
1110
1111
    return os;
}
Shucai Xiao's avatar
Shucai Xiao committed
1112
} // namespace op
1113

Shucai Xiao's avatar
Shucai Xiao committed
1114
1115
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx