// rewrite_rnn.cpp
// Implementation of the rewrite_rnn pass: dispatches rnn, gru and lstm
// instructions to their respective rewrite routines.
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
7
#include <migraphx/op/common.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
8
9
10
11

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Shucai Xiao's avatar
Shucai Xiao committed
12
// Entry point of the pass: visit every instruction of the program and hand
// each recurrent operator ("rnn", "gru" or "lstm") to its dedicated rewrite
// routine. Instructions with any other name are left untouched.
void rewrite_rnn::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        // Query the operator name once per instruction rather than once per
        // comparison.
        const auto op_name = ins->name();
        if(op_name == "rnn")
        {
            apply_vanilla_rnn(prog, ins);
        }
        else if(op_name == "gru")
        {
            apply_gru(prog, ins);
        }
        else if(op_name == "lstm")
        {
            apply_lstm(prog, ins);
        }
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
31
// Rewrite one "rnn" instruction into primitive instructions inserted in
// front of it, then replace the rnn instruction itself by a concat of the
// per-step hidden states. Every "rnn_last_output" consumer of the rnn
// instruction is replaced by the squeezed state computed in the last step.
//
// prog: program being rewritten.
// ins:  the "rnn" instruction to rewrite.
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "rnn");
    // could be 3 to 6 inputs, but the parse_rnn function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have any num of arguments
    // when writing their program.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    // zero-filled data used as initial hidden state when none is supplied
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0);

    // The bias is optional (4th input). The initial hidden state is only
    // picked up when all six inputs are present (it is read from index 5);
    // in every other case a zero initial state is used.
    const bool has_bias       = args.size() >= 4 && args[3]->name() != "undefined";
    const bool has_init_state = args.size() == 6 && args[5]->name() != "undefined";

    auto actv_funcs             = vanilla_rnn_actv_funcs(ins);
    auto rnn_op                 = any_cast<op::rnn>(ins->get_operator());
    op::rnn_direction direction = rnn_op.direction;
    instruction_ref last_output{};
    if(direction == op::rnn_direction::bidirectional)
    {
        // input weight matrix: index 0 on axis 0 is forward, index 1 is reverse
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // bias; prog.end() tells the cell that there is no bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(has_bias)
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // initial hidden state
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(has_init_state)
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward = vanilla_rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            actv_funcs.at(0));
        auto ret_reverse = vanilla_rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            actv_funcs.at(1));

        // last step of both directions, joined on the direction axis
        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic is to ensure the last instruction rewritten from
        // rnn operator is a concat instruction
        if(ret_forward[0] == prog.end())
        {
            // sequence len is 1: the cell produced no accumulated states
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
            // forward: earlier states first; reverse: the state computed
            // last is put in front of the accumulated ones
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (direction == op::rnn_direction::forward);
        // input weight matrix
        auto w = args[1];

        // hidden state weight matrix
        auto r = args[2];

        // bias; prog.end() tells the cell that there is no bias
        instruction_ref bias = prog.end();
        if(has_bias)
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih;
        if(has_init_state)
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret =
            vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        // following logic is to ensure the last instruction is a
        // concat instruction
        if(ret[0] == prog.end())
        {
            // sequence len is 1
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // Replace every rnn_last_output consumer (there can be several) by
    // last_output. The consumers are copied first so that the loop does not
    // hold iterators into ins->outputs() while replace_instruction rewires
    // the graph.
    const auto consumers = ins->outputs();
    for(auto consumer : consumers)
    {
        if(consumer->name() == "rnn_last_output")
        {
            prog.replace_instruction(consumer, last_output);
        }
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
187
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
188
189
190
191
192
193
194
195
                                                           program& prog,
                                                           instruction_ref ins,
                                                           instruction_ref input,
                                                           instruction_ref w,
                                                           instruction_ref r,
                                                           instruction_ref bias,
                                                           instruction_ref ih,
                                                           operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
196
{
Shucai Xiao's avatar
Shucai Xiao committed
197
198
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
199
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
Shucai Xiao's avatar
Shucai Xiao committed
200
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
201
202

    // squeeze and transpose r
Shucai Xiao's avatar
Shucai Xiao committed
203
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
Shucai Xiao's avatar
Shucai Xiao committed
204
205
206
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
Shucai Xiao's avatar
Shucai Xiao committed
207
    auto sih      = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
208
    auto sih_lens = sih->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
209
210

    // bias
211
    instruction_ref bb{};
Shucai Xiao's avatar
Shucai Xiao committed
212
    if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
213
    {
214
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
Shucai Xiao's avatar
Shucai Xiao committed
215
216
217
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
218
219
        auto wrb   = prog.insert_instruction(ins, op::add{}, wb, rb);
        bb         = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
Shucai Xiao's avatar
Shucai Xiao committed
220
221
    }

Shucai Xiao's avatar
Shucai Xiao committed
222
223
    instruction_ref hidden_out = prog.end();
    instruction_ref last_out{};
Shucai Xiao's avatar
Shucai Xiao committed
224
225
    last_out            = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
    std::size_t seq_len = input->get_shape().lens()[0];
Shucai Xiao's avatar
Shucai Xiao committed
226
227
    for(std::size_t i = 0; i < seq_len; i++)
    {
Shucai Xiao's avatar
Shucai Xiao committed
228
        long seq_index = is_forward ? i : (seq_len - 1 - i);
Shucai Xiao's avatar
Shucai Xiao committed
229
230
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
231
232
        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
Shucai Xiao's avatar
Shucai Xiao committed
233
        if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
234
        {
235
            xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
Shucai Xiao's avatar
Shucai Xiao committed
236
        }
237
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
Shucai Xiao's avatar
Shucai Xiao committed
238
239

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
240
241
        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;
Shucai Xiao's avatar
Shucai Xiao committed
242

Shucai Xiao's avatar
Shucai Xiao committed
243
244
245
        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
Shucai Xiao's avatar
Shucai Xiao committed
246

Shucai Xiao's avatar
Shucai Xiao committed
247
248
249
        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
Shucai Xiao's avatar
Shucai Xiao committed
250
        if(i < seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
251
        {
Shucai Xiao's avatar
Shucai Xiao committed
252
253
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
254
255
256
257
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
258
259
260
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
261
262
263
264
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
265
            }
Shucai Xiao's avatar
Shucai Xiao committed
266
267
268
        }
    }

269
    return {hidden_out, last_out};
Shucai Xiao's avatar
Shucai Xiao committed
270
271
}

Shucai Xiao's avatar
Shucai Xiao committed
272
// Build the list of activation functions used to rewrite an "rnn"
// instruction.
//
// A bidirectional rnn yields two functions (forward, reverse): tanh for both
// when none was specified, the single specified one duplicated when only one
// was given, otherwise the operator's own list unchanged. Any other direction
// yields {tanh} when none was specified, otherwise the operator's own list
// unchanged.
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
    const bool bidirectional = (rnn_op.direction == op::rnn_direction::bidirectional);

    if(rnn_op.actv_funcs.empty())
    {
        // default is tanh
        if(bidirectional)
        {
            return {op::tanh{}, op::tanh{}};
        }
        return {op::tanh{}};
    }

    if(bidirectional && rnn_op.actv_funcs.size() == 1)
    {
        // one function given: use it for both directions
        return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
    }

    return rnn_op.actv_funcs;
}

309
310
311
312
// Rewrite one "gru" instruction into primitive instructions inserted in
// front of it, then replace the gru instruction itself by a concat of the
// per-step hidden states. Every "rnn_last_output" consumer of the gru
// instruction is replaced by the squeezed state computed in the last step.
//
// prog: program being rewritten.
// ins:  the "gru" instruction to rewrite.
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);
    // could be 3 to 6 inputs, but the parse_gru function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have any num of arguments
    // when writing their program.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    // zero-filled data used as initial hidden state when none is supplied
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    std::vector<float> data(ih_shape.elements(), 0.0);

    // The bias is optional (4th input). The initial hidden state is only
    // picked up when all six inputs are present (it is read from index 5);
    // in every other case a zero initial state is used.
    const bool has_bias       = args.size() >= 4 && args[3]->name() != "undefined";
    const bool has_init_state = args.size() == 6 && args[5]->name() != "undefined";

    auto gru_op                 = any_cast<op::gru>(ins->get_operator());
    op::rnn_direction direction = gru_op.direction;
    instruction_ref last_output{};
    if(direction == op::rnn_direction::bidirectional)
    {
        // w weight matrix: index 0 on axis 0 is forward, index 1 is reverse
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // r weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // bias; prog.end() tells the cell that there is no bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(has_bias)
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // initial hidden state
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(has_init_state)
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        // activation functions 0/1 drive the forward cell, 2/3 the reverse one
        auto ret_forward = gru_cell(true,
                                    prog,
                                    ins,
                                    {args[0], w_forward, r_forward, bias_forward, ih_forward},
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(0),
                                    actv_funcs.at(1));

        auto ret_reverse = gru_cell(false,
                                    prog,
                                    ins,
                                    {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(2),
                                    actv_funcs.at(3));

        // last step of both directions, joined on the direction axis
        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic is to ensure the last instruction rewritten
        // from gru operator is a concat
        if(ret_forward[0] == prog.end())
        {
            // the cell produced no accumulated states (single step)
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
            // forward: earlier states first; reverse: the state computed
            // last is put in front of the accumulated ones
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (direction == op::rnn_direction::forward);
        // weight matrix
        auto w = args[1];
        auto r = args[2];

        // bias; prog.end() tells the cell that there is no bias
        instruction_ref bias = prog.end();
        if(has_bias)
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
        if(has_init_state)
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret = gru_cell(is_forward,
                            prog,
                            ins,
                            {args[0], w, r, bias, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // Replace every rnn_last_output consumer (there can be several) by
    // last_output. The consumers are copied first so that the loop does not
    // hold iterators into ins->outputs() while replace_instruction rewires
    // the graph.
    const auto consumers = ins->outputs();
    for(auto consumer : consumers)
    {
        if(consumer->name() == "rnn_last_output")
        {
            prog.replace_instruction(consumer, last_output);
        }
    }
}

std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
463
464
465
466
467
468
                                                   program& prog,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
469
470
471
472
473
474
475
476
{
    assert(inputs.size() == 5);
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
    auto bias = inputs.at(3);
    auto ih   = inputs.at(4);

Shucai Xiao's avatar
Shucai Xiao committed
477
478
    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
Shucai Xiao's avatar
Shucai Xiao committed
479
    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
480
481
482
    migraphx::shape r_shape   = r->get_shape();
    long seq_len              = static_cast<long>(seq_shape.lens()[0]);
    long hs                   = static_cast<long>(r_shape.lens()[2]);
483

Shucai Xiao's avatar
Shucai Xiao committed
484
    migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
485
    std::vector<float> data(s.elements(), 1.0f);
486
487
    auto l1 = prog.add_literal(migraphx::literal{s, data});

488
    // w matrix squeeze to 2-dim and do a transpose
489
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
490
    auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
491
    auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
492

493
    // r slide to two part, zr and h
Shucai Xiao's avatar
Shucai Xiao committed
494
495
496
    auto sr   = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rzr  = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
    auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
497

Shucai Xiao's avatar
Shucai Xiao committed
498
    auto rh  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
499
    auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
500
501

    // initial states
Shucai Xiao's avatar
Shucai Xiao committed
502
503
    auto sih  = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
    size_t bs = ih->get_shape().lens()[1];
504
505

    // bias
506
507
508
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
509
510
511
    if(bias != prog.end())
    {
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
512
513
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
        bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
514
515

        auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
Shucai Xiao's avatar
Shucai Xiao committed
516
517
518
519
        auto rb_h  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brb_zr     = prog.insert_instruction(
            ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
        brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
520
521
522
523
524
525
526
527
    }

    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

528
529
        auto xt_w    = prog.insert_instruction(ins, op::dot{}, xt, tw);
        auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
Shucai Xiao's avatar
Shucai Xiao committed
530
        if(bias != prog.end())
531
        {
532
533
            xt_w    = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
            ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
534
535
536
537
538
539
540
541
542
543
        }

        auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
        auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
        auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);

        auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
        auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);

        auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
Shucai Xiao's avatar
Shucai Xiao committed
544
        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);
545
546

        auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
Shucai Xiao's avatar
Shucai Xiao committed
547
        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);
548
549
550
551
552
553

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
Shucai Xiao's avatar
Shucai Xiao committed
554
            if(bias != prog.end())
555
            {
Shucai Xiao's avatar
Shucai Xiao committed
556
                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh, brb_h);
557
            }
558
559
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
560
                hr_h = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
561
            }
562
563
564
        }
        else
        {
565
566
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            instruction_ref ht1_rh{};
Shucai Xiao's avatar
Shucai Xiao committed
567
            if(bias != prog.end())
568
            {
569
                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh, brb_h);
570
            }
571
            else
572
            {
Shucai Xiao's avatar
Shucai Xiao committed
573
                ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
574
            }
Shucai Xiao's avatar
Shucai Xiao committed
575
            hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
576
577
        }

578
        auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
Shucai Xiao's avatar
Shucai Xiao committed
579
        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);
580

581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
        last_output          = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}

// Build the list of activation functions used to rewrite a "gru"
// instruction.
//
// before rewrite the gru operator, need to ensure we have enough actv
// funcs, even though a user does not specify any actv func: 4 for a
// bidirectional gru (two per direction), 2 otherwise. Missing entries are
// filled from the ones supplied; a list that is already long enough is
// returned unchanged.
std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    auto gru_op  = any_cast<op::gru>(ins->get_operator());
    auto& funcs  = gru_op.actv_funcs;
    const auto n = funcs.size();

    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        switch(n)
        {
        // defaults: sigmoid for the gates, tanh for the hidden state
        case 0: return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
        // a single function is used everywhere
        case 1: return {funcs.at(0), funcs.at(0), funcs.at(0), funcs.at(0)};
        // a pair is reused for the second direction
        case 2: return {funcs.at(0), funcs.at(1), funcs.at(0), funcs.at(1)};
        // three given: the fourth slot repeats the first function
        case 3: return {funcs.at(0), funcs.at(1), funcs.at(2), funcs.at(0)};
        default: return funcs;
        }
    }

    switch(n)
    {
    case 0: return {op::sigmoid{}, op::tanh{}};
    case 1: return {funcs.at(0), funcs.at(0)};
    default: return funcs;
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
650
651
652
653
654
// for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();
Shucai Xiao's avatar
Shucai Xiao committed
655

Shucai Xiao's avatar
Shucai Xiao committed
656
    shape seq_shape         = args[0]->get_shape();
657
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
Shucai Xiao's avatar
Shucai Xiao committed
658
659
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
Shucai Xiao's avatar
Shucai Xiao committed
660
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
661
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
Shucai Xiao's avatar
Shucai Xiao committed
662
663

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
664
    std::vector<float> pph_data(pph_shape.elements(), 0.0);
Shucai Xiao's avatar
Shucai Xiao committed
665

Shucai Xiao's avatar
Shucai Xiao committed
666
667
    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
Shucai Xiao's avatar
Shucai Xiao committed
668
    op::rnn_direction dirct = lstm_op.direction;
Shucai Xiao's avatar
Shucai Xiao committed
669
670
671

    instruction_ref last_output{};
    instruction_ref last_cell_output{};
Shucai Xiao's avatar
Shucai Xiao committed
672
    if(dirct == op::rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
673
674
675
676
677
678
679
680
681
682
683
684
685
    {
        // input weight matrix
        // input weight matrix
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // process bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
686
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // process intial hidden state, it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
709
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
710
711
712
713
714
715
716
717
718
719
720
        {
            ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]);
            ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]);
        }
        else
        {
            ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
721
722
        instruction_ref pph_forward = prog.end();
        instruction_ref pph_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
723
        if(args.size() == 8 && args[7]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
724
725
726
727
        {
            pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
            pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
        }
Shucai Xiao's avatar
Shucai Xiao committed
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748

        auto ret_forward = lstm_cell(
            true,
            prog,
            ins,
            {args[0], w_forward, r_forward, bias_forward, ih_forward, ic_forward, pph_forward},
            actv_funcs.at(0),
            actv_funcs.at(1),
            actv_funcs.at(2));

        auto ret_reverse = lstm_cell(
            false,
            prog,
            ins,
            {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse, ic_reverse, pph_reverse},
            actv_funcs.at(3),
            actv_funcs.at(4),
            actv_funcs.at(5));

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
749
750
751
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // last cell output
752
753
        last_cell_output =
            prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_reverse[2]);
Shucai Xiao's avatar
Shucai Xiao committed
754
755

        // the following logic is to ensure the last instruction is a concat
Shucai Xiao's avatar
Shucai Xiao committed
756
        if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
757
758
759
760
761
        {
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
762
763
764
765
766
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
Shucai Xiao's avatar
Shucai Xiao committed
767
768
769
770
        }
    }
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
771
        bool is_forward = (dirct == op::rnn_direction::forward);
Shucai Xiao's avatar
Shucai Xiao committed
772
773
774
775
776
777
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
778
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
779
780
781
782
783
784
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
Shucai Xiao's avatar
Shucai Xiao committed
785
        if(args.size() >= 6 && args[5]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
786
787
788
789
790
791
792
793
794
795
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
Shucai Xiao's avatar
Shucai Xiao committed
796
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
797
798
799
800
801
802
803
804
805
        {
            ic = args[6];
        }
        else
        {
            ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
806
        instruction_ref pph = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
807
808
809
810
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }
Shucai Xiao's avatar
Shucai Xiao committed
811
812

        auto ret = lstm_cell(is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
813
814
815
816
817
818
819
                             prog,
                             ins,
                             {args[0], w, r, bias, ih, ic, pph},
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

Shucai Xiao's avatar
Shucai Xiao committed
820
        last_output      = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
821
        last_cell_output = ret[2];
Shucai Xiao's avatar
Shucai Xiao committed
822
        if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
823
824
825
826
827
828
829
830
831
832
833
834
835
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // replace the corresponding lstm_last_output instruction
    // with the last_output, and the lstm_last_cell_output with
Shucai Xiao's avatar
Shucai Xiao committed
836
    // the last_cell_output. The while loop is to handle the case
Shucai Xiao's avatar
Shucai Xiao committed
837
838
839
840
841
842
    // of multiple lstm_last_output and lstm_last_cell_output
    // operators
    auto last_output_it = ins->outputs().begin();
    while(last_output_it != ins->outputs().end())
    {
        last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
843
            return i->name() == "rnn_last_output";
Shucai Xiao's avatar
Shucai Xiao committed
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
        });

        if(last_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_output_it, last_output);
            last_output_it++;
        }
    }

    auto last_cell_output_it = ins->outputs().begin();
    while(last_cell_output_it != ins->outputs().end())
    {
        last_cell_output_it = std::find_if(last_cell_output_it, ins->outputs().end(), [](auto i) {
            return i->name() == "lstm_last_cell_output";
        });

        if(last_cell_output_it != ins->outputs().end())
        {
            prog.replace_instruction(*last_cell_output_it, last_cell_output);
            last_cell_output_it++;
        }
Shucai Xiao's avatar
Shucai Xiao committed
865
    }
Shucai Xiao's avatar
Shucai Xiao committed
866
867
868
}

std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
869
870
871
872
873
874
                                                    program& prog,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
Shucai Xiao's avatar
Shucai Xiao committed
875
{
Shucai Xiao's avatar
Shucai Xiao committed
876
877
    // must have 7 args in the input vector
    assert(inputs.size() == 7);
Shucai Xiao's avatar
Shucai Xiao committed
878
879
880
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
Shucai Xiao's avatar
Shucai Xiao committed
881
    auto bias = inputs.at(3);
Shucai Xiao's avatar
Shucai Xiao committed
882
883
884
    auto ih   = inputs.at(4);
    auto ic   = inputs.at(5);
    auto pph  = inputs.at(6);
Shucai Xiao's avatar
Shucai Xiao committed
885

Shucai Xiao's avatar
Shucai Xiao committed
886
887
888
889
890
    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
    instruction_ref last_cell_output{};

    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
891
    migraphx::shape r_shape   = r->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
892
    long seq_len              = static_cast<long>(seq_shape.lens()[0]);
Shucai Xiao's avatar
Shucai Xiao committed
893
    long hs                   = static_cast<long>(r_shape.lens()[2]);
Shucai Xiao's avatar
Shucai Xiao committed
894
    auto bs                   = ih->get_shape().lens()[1];
Shucai Xiao's avatar
Shucai Xiao committed
895
896

    std::vector<int64_t> perm{1, 0};
897
    // w matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
898
    auto sw  = prog.insert_instruction(ins, op::squeeze{{0}}, w);
899
    auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
900

901
    // r matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
902
    auto sr  = prog.insert_instruction(ins, op::squeeze{{0}}, r);
903
    auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);
Shucai Xiao's avatar
Shucai Xiao committed
904

Shucai Xiao's avatar
Shucai Xiao committed
905
906
907
908
    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // initial cell state
Shucai Xiao's avatar
Shucai Xiao committed
909
    auto sic     = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
910
    auto ic_lens = sic->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
911

912
    // bias
913
    instruction_ref wrb{};
Shucai Xiao's avatar
Shucai Xiao committed
914
    if(bias != prog.end())
915
    {
916

Shucai Xiao's avatar
Shucai Xiao committed
917
918
919
        auto sbias  = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto ub_wb  = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
        auto ub_rb  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
920
        auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
921

922
923
        wrb = prog.insert_instruction(
            ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
924
925
    }

Shucai Xiao's avatar
Shucai Xiao committed
926
    // peep hole
Shucai Xiao's avatar
Shucai Xiao committed
927
928
929
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
Shucai Xiao's avatar
Shucai Xiao committed
930
    if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
931
    {
Shucai Xiao's avatar
Shucai Xiao committed
932
933
        auto spph  = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
        auto pphi  = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
934
        pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi);
Shucai Xiao's avatar
Shucai Xiao committed
935

Shucai Xiao's avatar
Shucai Xiao committed
936
        auto ppho  = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
937
        ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho);
Shucai Xiao's avatar
Shucai Xiao committed
938

Shucai Xiao's avatar
Shucai Xiao committed
939
        auto pphf  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
940
        pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf);
Shucai Xiao's avatar
Shucai Xiao committed
941
    }
Shucai Xiao's avatar
Shucai Xiao committed
942

Shucai Xiao's avatar
Shucai Xiao committed
943
    for(long i = 0; i < seq_len; ++i)
Shucai Xiao's avatar
Shucai Xiao committed
944
945
946
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
Shucai Xiao's avatar
Shucai Xiao committed
947
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
Shucai Xiao's avatar
Shucai Xiao committed
948

949
950
        auto xt_tsw  = prog.insert_instruction(ins, op::dot{}, xt, tsw);
        auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
Shucai Xiao's avatar
Shucai Xiao committed
951
        auto xt_sih  = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
Shucai Xiao's avatar
Shucai Xiao committed
952
        if(bias != prog.end())
953
        {
Shucai Xiao's avatar
Shucai Xiao committed
954
            xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
955
        }
Shucai Xiao's avatar
Shucai Xiao committed
956

957
958
        auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
        auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
Shucai Xiao's avatar
Shucai Xiao committed
959
960
961
962
        auto ft_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
        auto ct_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
963

Shucai Xiao's avatar
Shucai Xiao committed
964
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
965
        {
Shucai Xiao's avatar
Shucai Xiao committed
966
967
968
            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);

Shucai Xiao's avatar
Shucai Xiao committed
969
970
            auto pphf_ct   = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
            ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
Shucai Xiao's avatar
Shucai Xiao committed
971
        }
Shucai Xiao's avatar
Shucai Xiao committed
972
        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
973
        auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
974
        auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
975
976

        // equation Ct = ft (.) Ct-1 + it (.) ct
Shucai Xiao's avatar
Shucai Xiao committed
977
978
979
        auto ft_cell     = prog.insert_instruction(ins, op::mul{}, ft, sic);
        auto it_ct       = prog.insert_instruction(ins, op::mul{}, it, ct);
        auto cellt       = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
980
981
        last_cell_output = cellt;

Shucai Xiao's avatar
Shucai Xiao committed
982
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
983
        {
Shucai Xiao's avatar
Shucai Xiao committed
984
985
            auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
            ot_before_actv  = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
Shucai Xiao's avatar
Shucai Xiao committed
986
        }
987
988
989
990
        auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);

        // Ht = ot (.) h(Ct)
        auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
Shucai Xiao's avatar
Shucai Xiao committed
991
        auto ht      = prog.insert_instruction(ins, op::mul{}, ot, h_cellt);
992
993
994
995
996
997

        sic = cellt;
        sih = ht;

        last_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);

Shucai Xiao's avatar
Shucai Xiao committed
998
        if(i < seq_len - 1)
999
        {
Shucai Xiao's avatar
Shucai Xiao committed
1000
            if(i == 0)
1001
            {
Shucai Xiao's avatar
Shucai Xiao committed
1002
                hidden_states = last_output;
1003
1004
1005
1006
1007
            }
            else
            {
                auto concat_arg0 = is_forward ? hidden_states : last_output;
                auto concat_arg1 = is_forward ? last_output : hidden_states;
Shucai Xiao's avatar
Shucai Xiao committed
1008
1009
                hidden_states =
                    prog.insert_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
1010
1011
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1012
    }
1013
1014
1015
1016

    last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0}}, last_cell_output);

    return {hidden_states, last_output, last_cell_output};
Shucai Xiao's avatar
Shucai Xiao committed
1017
1018
1019
1020
1021
1022
1023
1024
1025
}

// Returns the activation functions to use when rewriting an lstm
// instruction. The rewrite expects 6 functions for a bidirectional lstm
// and 3 otherwise; when the operator carries fewer, the list is padded
// here (the original note says this follows the algorithm in parse_lstm).
// A list that already has enough entries is returned unchanged.
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    auto lstm_op         = any_cast<op::lstm>(ins->get_operator());
    const auto& given    = lstm_op.actv_funcs;
    std::size_t num_given = given.size();

    // builds a list out of the given functions, selected by position
    auto pick = [&](std::initializer_list<std::size_t> positions) {
        std::vector<operation> picked;
        picked.reserve(positions.size());
        for(auto pos : positions)
        {
            picked.push_back(given.at(pos));
        }
        return picked;
    };

    if(lstm_op.direction == op::rnn_direction::bidirectional)
    {
        // 6 functions are needed, 3 per direction
        if(num_given == 0)
        {
            return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};
        }
        if(num_given == 1)
        {
            return pick({0, 0, 0, 0, 0, 0});
        }
        if(num_given == 2)
        {
            return pick({0, 1, 1, 0, 1, 1});
        }
        if(num_given == 3)
        {
            return pick({0, 1, 2, 0, 1, 2});
        }
        if(num_given == 4)
        {
            return pick({0, 1, 2, 3, 3, 3});
        }
        if(num_given == 5)
        {
            return pick({0, 1, 2, 3, 4, 4});
        }
        return given;
    }

    // forward or reverse only: 3 functions are needed
    if(num_given == 0)
    {
        return {op::sigmoid{}, op::tanh{}, op::tanh{}};
    }
    if(num_given == 1)
    {
        return pick({0, 0, 0});
    }
    if(num_given == 2)
    {
        return pick({0, 1, 1});
    }
    return given;
}

1093
namespace op {
Shucai Xiao's avatar
Shucai Xiao committed
1094
std::ostream& operator<<(std::ostream& os, rnn_direction v)
1095
{
Shucai Xiao's avatar
Shucai Xiao committed
1096
1097
    std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
1098
1099
    return os;
}
Shucai Xiao's avatar
Shucai Xiao committed
1100
} // namespace op
1101

Shucai Xiao's avatar
Shucai Xiao committed
1102
1103
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx