rewrite_rnn.cpp 50.8 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
17
#include <migraphx/op/contiguous.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
18
19
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
20
#include <migraphx/op/common.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
21
22
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
23
24
25
26

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

Shucai Xiao's avatar
Shucai Xiao committed
27
// Visit every instruction of the program and hand each recurrent operator
// ("rnn", "gru" or "lstm") to the rewriter dedicated to it. Instructions
// with any other name are left untouched.
void rewrite_rnn::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        // read the name once, before any rewriter touches the instruction
        const auto op_name = ins->name();

        if(op_name == "rnn")
        {
            apply_vanilla_rnn(prog, ins);
            continue;
        }

        if(op_name == "gru")
        {
            apply_gru(prog, ins);
            continue;
        }

        if(op_name == "lstm")
        {
            apply_lstm(prog, ins);
        }
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
46
// Rewrite a single "rnn" instruction.
//
// The recurrence is unrolled by vanilla_rnn_cell() into instructions inserted
// in front of `ins`; `ins` itself is then replaced (in place) by the concat
// that produces all hidden states. Every user of `ins` whose name is
// "rnn_last_hs_output" is redirected to the last hidden state.
//
// prog: program being rewritten.
// ins:  the "rnn" instruction (asserted).
void rewrite_rnn::apply_vanilla_rnn(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "rnn");
    // could be 3 to 6 inputs, but the parse_rnn function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have any num of arguments
    // when writing their program.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[1]->get_shape().lens()[1];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    // zeros used as the initial hidden state when the caller supplies none
    std::vector<float> data(ih_shape.elements(), 0);

    auto actv_funcs         = vanilla_rnn_actv_funcs(ins);
    auto rnn_op             = any_cast<op::rnn>(ins->get_operator());
    op::rnn_direction dicrt = rnn_op.direction;
    instruction_ref last_output{};
    if(dicrt == op::rnn_direction::bidirectional)
    {
        // input weight matrix: index 0 on axis 0 is forward, index 1 is reverse
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // process bias; prog.end() is the "no bias" marker understood by
        // vanilla_rnn_cell()
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // process initial hidden state: it is only picked up when all six
        // arguments are present and the sixth is not "undefined"; otherwise
        // a zero literal is used
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret_forward = vanilla_rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            actv_funcs.at(0));
        auto ret_reverse = vanilla_rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            actv_funcs.at(1));

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic is to ensure the last instruction rewritten from
        // rnn operator is a concat instruction
        // sequence len is 1
        if(ret_forward[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
            // append the final step to the earlier ones; the reverse
            // direction puts its final step first
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dicrt == op::rnn_direction::forward);
        // input weight matrix
        auto w = args[1];

        // hidden state weight matrix
        auto r = args[2];

        // process bias; prog.end() means "no bias"
        instruction_ref bias = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // process initial hidden state (zero literal when not supplied)
        instruction_ref ih;
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret =
            vanilla_rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        // following logic is to ensure the last instruction is a
        // concat instruction
        // sequence len is 1
        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // Redirect every rnn_last_hs_output user of this instruction to the last
    // hidden state. The matching users are collected first: replacing a user
    // can rewire it, and walking ins->outputs() with a live iterator while
    // that list is being edited would leave the iterator invalid.
    std::vector<instruction_ref> last_hs_users;
    for(auto out : ins->outputs())
    {
        if(out->name() == "rnn_last_hs_output")
        {
            last_hs_users.push_back(out);
        }
    }

    for(auto user : last_hs_users)
    {
        prog.replace_instruction(user, last_output);
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
202
std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
203
204
205
206
207
208
209
210
                                                           program& prog,
                                                           instruction_ref ins,
                                                           instruction_ref input,
                                                           instruction_ref w,
                                                           instruction_ref r,
                                                           instruction_ref bias,
                                                           instruction_ref ih,
                                                           operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
211
{
Shucai Xiao's avatar
Shucai Xiao committed
212
213
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
214
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
Shucai Xiao's avatar
Shucai Xiao committed
215
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
216
217

    // squeeze and transpose r
Shucai Xiao's avatar
Shucai Xiao committed
218
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
Shucai Xiao's avatar
Shucai Xiao committed
219
220
221
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
Shucai Xiao's avatar
Shucai Xiao committed
222
    auto sih      = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
223
    auto sih_lens = sih->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
224
225

    // bias
226
    instruction_ref bb{};
Shucai Xiao's avatar
Shucai Xiao committed
227
    if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
228
    {
229
        long hs    = static_cast<long>(r->get_shape().lens()[2]);
Shucai Xiao's avatar
Shucai Xiao committed
230
231
232
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
233
234
        auto wrb   = prog.insert_instruction(ins, op::add{}, wb, rb);
        bb         = prog.insert_instruction(ins, op::broadcast{1, sih_lens}, wrb);
Shucai Xiao's avatar
Shucai Xiao committed
235
236
    }

Shucai Xiao's avatar
Shucai Xiao committed
237
238
    instruction_ref hidden_out = prog.end();
    instruction_ref last_out{};
Shucai Xiao's avatar
Shucai Xiao committed
239
240
    last_out            = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
    std::size_t seq_len = input->get_shape().lens()[0];
Shucai Xiao's avatar
Shucai Xiao committed
241
242
    for(std::size_t i = 0; i < seq_len; i++)
    {
Shucai Xiao's avatar
Shucai Xiao committed
243
        long seq_index = is_forward ? i : (seq_len - 1 - i);
Shucai Xiao's avatar
Shucai Xiao committed
244
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
Shucai Xiao's avatar
Shucai Xiao committed
245
246
247
248
        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
        xt           = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
        auto xt_wi   = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri   = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
Shucai Xiao's avatar
Shucai Xiao committed
249
        if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
250
        {
251
            xt_wi = prog.insert_instruction(ins, op::add{}, xt_wi, bb);
Shucai Xiao's avatar
Shucai Xiao committed
252
        }
253
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
Shucai Xiao's avatar
Shucai Xiao committed
254
255

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
256
257
        auto ht = prog.insert_instruction(ins, actv_func, xt_ht);
        sih     = ht;
Shucai Xiao's avatar
Shucai Xiao committed
258

Shucai Xiao's avatar
Shucai Xiao committed
259
260
261
        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
Shucai Xiao's avatar
Shucai Xiao committed
262

Shucai Xiao's avatar
Shucai Xiao committed
263
264
265
        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
Shucai Xiao's avatar
Shucai Xiao committed
266
        if(i < seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
267
        {
Shucai Xiao's avatar
Shucai Xiao committed
268
269
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
270
271
272
273
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
274
275
276
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
277
278
279
280
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
281
            }
Shucai Xiao's avatar
Shucai Xiao committed
282
283
284
        }
    }

285
    return {hidden_out, last_out};
Shucai Xiao's avatar
Shucai Xiao committed
286
287
}

Shucai Xiao's avatar
Shucai Xiao committed
288
// Work out the activation functions to use for an "rnn" instruction.
//
// When the operator carries none, tanh is used: two copies for a
// bidirectional operator, one otherwise. A bidirectional operator that
// carries exactly one function gets it duplicated so both directions share
// it. In every other case the operator's own list is returned unchanged.
std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins) const
{
    const auto rnn_op      = any_cast<op::rnn>(ins->get_operator());
    const auto& configured = rnn_op.actv_funcs;
    const bool two_dirs    = (rnn_op.direction == op::rnn_direction::bidirectional);

    if(configured.empty())
    {
        // default is tanh
        if(two_dirs)
        {
            return {op::tanh{}, op::tanh{}};
        }
        return {op::tanh{}};
    }

    if(two_dirs && configured.size() == 1)
    {
        return {configured.at(0), configured.at(0)};
    }

    return configured;
}

325
326
327
328
// Rewrite a single "gru" instruction.
//
// The recurrence is unrolled by gru_cell() into instructions inserted in
// front of `ins`; `ins` itself is then replaced (in place) by the concat
// that produces all hidden states. Every user of `ins` whose name is
// "rnn_last_hs_output" is redirected to the last hidden state.
//
// prog: program being rewritten.
// ins:  the "gru" instruction (asserted).
void rewrite_rnn::apply_gru(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "gru");
    const auto actv_funcs = gru_actv_funcs(ins);
    // could be 3 to 6 inputs, but the parse_gru function will
    // append undefined operators to make 6 arguments when parsing
    // an onnx file. Another case is user can have any num of arguments
    // when writing their program.
    auto args = ins->inputs();

    shape seq_shape         = args[0]->get_shape();
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
    migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
    // zeros used as the initial hidden state when the caller supplies none
    std::vector<float> data(ih_shape.elements(), 0.0);

    auto gru_op             = any_cast<op::gru>(ins->get_operator());
    op::rnn_direction dicrt = gru_op.direction;
    instruction_ref last_output{};
    if(dicrt == op::rnn_direction::bidirectional)
    {
        // w weight matrix: index 0 on axis 0 is forward, index 1 is reverse
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // r weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // bias; prog.end() is the "no bias" marker understood by gru_cell()
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // initial hidden state: only picked up when all six arguments are
        // present and the sixth is not "undefined"; otherwise a zero literal
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
            ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        // activation functions 0/1 drive the forward pass, 2/3 the reverse
        auto ret_forward = gru_cell(true,
                                    prog,
                                    ins,
                                    {args[0], w_forward, r_forward, bias_forward, ih_forward},
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(0),
                                    actv_funcs.at(1));

        auto ret_reverse = gru_cell(false,
                                    prog,
                                    ins,
                                    {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(2),
                                    actv_funcs.at(3));

        auto concat_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

        // The following logic is to ensure the last instruction rewritten
        // from gru operator is a concat
        if(ret_forward[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
        }
        else
        {
            // append the final step to the earlier ones; the reverse
            // direction puts its final step first
            ret_forward[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
            ret_reverse[0] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
    }
    else
    {
        bool is_forward = (dicrt == op::rnn_direction::forward);
        // weight matrix
        auto w = args[1];
        auto r = args[2];

        // bias; prog.end() means "no bias"
        instruction_ref bias = prog.end();
        if(args.size() >= 4 && args[3]->name() != "undefined")
        {
            bias = args[3];
        }

        // initial hidden state (zero literal when not supplied)
        instruction_ref ih{};
        if(args.size() == 6 && args[5]->name() != "undefined")
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ih_shape, data});
        }

        auto ret = gru_cell(is_forward,
                            prog,
                            ins,
                            {args[0], w, r, bias, ih},
                            gru_op.linear_before_reset,
                            actv_funcs.at(0),
                            actv_funcs.at(1));

        last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

        // ret[0] == prog.end() means there are no earlier steps to prepend
        if(ret[0] == prog.end())
        {
            prog.replace_instruction(ins, op::concat{0}, ret[1]);
        }
        else
        {
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
            prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
        }
    }

    // Replace every rnn_last_hs_output user of this instruction with
    // last_output. The matching users are collected first: replacing a user
    // can rewire it, and walking ins->outputs() with a live iterator while
    // that list is being edited would leave the iterator invalid.
    std::vector<instruction_ref> last_hs_users;
    for(auto out : ins->outputs())
    {
        if(out->name() == "rnn_last_hs_output")
        {
            last_hs_users.push_back(out);
        }
    }

    for(auto user : last_hs_users)
    {
        prog.replace_instruction(user, last_output);
    }
}

std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
479
480
481
482
483
484
                                                   program& prog,
                                                   instruction_ref ins,
                                                   std::vector<instruction_ref> inputs,
                                                   int linear_before_reset,
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
485
486
487
488
489
490
491
492
{
    assert(inputs.size() == 5);
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
    auto bias = inputs.at(3);
    auto ih   = inputs.at(4);

Shucai Xiao's avatar
Shucai Xiao committed
493
494
    instruction_ref hidden_states = prog.end();
    instruction_ref last_output{};
Shucai Xiao's avatar
Shucai Xiao committed
495
    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
496
497
498
    migraphx::shape r_shape   = r->get_shape();
    long seq_len              = static_cast<long>(seq_shape.lens()[0]);
    long hs                   = static_cast<long>(r_shape.lens()[2]);
499

Shucai Xiao's avatar
Shucai Xiao committed
500
    migraphx::shape s(seq_shape.type(), {seq_shape.lens()[1], r_shape.lens()[2]});
501
    std::vector<float> data(s.elements(), 1.0f);
502
503
    auto l1 = prog.add_literal(migraphx::literal{s, data});

504
    // w matrix squeeze to 2-dim and do a transpose
505
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
506
    auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
507
    auto tw = prog.insert_instruction(ins, op::transpose{perm}, sw);
508

509
    // r slide to two part, zr and h
Shucai Xiao's avatar
Shucai Xiao committed
510
511
512
    auto sr   = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rzr  = prog.insert_instruction(ins, op::slice{{0}, {0}, {2 * hs}}, sr);
    auto trzr = prog.insert_instruction(ins, op::transpose{perm}, rzr);
513

Shucai Xiao's avatar
Shucai Xiao committed
514
    auto rh  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
515
    auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);
516
517

    // initial states
Shucai Xiao's avatar
Shucai Xiao committed
518
519
    auto sih  = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
    size_t bs = ih->get_shape().lens()[1];
520
521

    // bias
522
523
524
    instruction_ref bwb{};
    instruction_ref brb_zr{};
    instruction_ref brb_h{};
525
526
527
    if(bias != prog.end())
    {
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
528
529
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {3 * hs}}, sbias);
        bwb = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(3 * hs)}}, wb);
530
531

        auto rb_zr = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {5 * hs}}, sbias);
Shucai Xiao's avatar
Shucai Xiao committed
532
533
534
535
        auto rb_h  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brb_zr     = prog.insert_instruction(
            ins, op::broadcast{1, {bs, static_cast<size_t>(2 * hs)}}, rb_zr);
        brb_h = prog.insert_instruction(ins, op::broadcast{1, {bs, static_cast<size_t>(hs)}}, rb_h);
536
537
538
539
540
541
    }

    for(long i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
Shucai Xiao's avatar
Shucai Xiao committed
542
543
        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
        xt           = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
544

545
546
        auto xt_w    = prog.insert_instruction(ins, op::dot{}, xt, tw);
        auto ih1_rzr = prog.insert_instruction(ins, op::dot{}, sih, trzr);
Shucai Xiao's avatar
Shucai Xiao committed
547
        if(bias != prog.end())
548
        {
549
550
            xt_w    = prog.insert_instruction(ins, op::add{}, xt_w, bwb);
            ih1_rzr = prog.insert_instruction(ins, op::add{}, ih1_rzr, brb_zr);
551
552
553
554
555
556
557
558
559
560
        }

        auto xw_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_w);
        auto xw_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_w);
        auto xw_h = prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_w);

        auto hr_z = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, ih1_rzr);
        auto hr_r = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, ih1_rzr);

        auto xw_hr_z = prog.insert_instruction(ins, op::add{}, xw_z, hr_z);
Shucai Xiao's avatar
Shucai Xiao committed
561
        auto zt      = prog.insert_instruction(ins, actv_func1, xw_hr_z);
562
563

        auto xw_hr_r = prog.insert_instruction(ins, op::add{}, xw_r, hr_r);
Shucai Xiao's avatar
Shucai Xiao committed
564
        auto rt      = prog.insert_instruction(ins, actv_func1, xw_hr_r);
565
566
567
568
569
570

        instruction_ref hr_h{};
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
571
            hr_h        = prog.insert_instruction(ins, op::dot{}, rt_ht1, trh);
Shucai Xiao's avatar
Shucai Xiao committed
572
            if(bias != prog.end())
573
            {
574
                hr_h = prog.insert_instruction(ins, op::add{}, hr_h, brb_h);
575
            }
576
577
578
        }
        else
        {
579
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
580
            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, trh);
Shucai Xiao's avatar
Shucai Xiao committed
581
            if(bias != prog.end())
582
            {
583
                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brb_h);
584
            }
Shucai Xiao's avatar
Shucai Xiao committed
585
            hr_h = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
586
587
        }

588
        auto xw_hr_h = prog.insert_instruction(ins, op::add{}, xw_h, hr_h);
Shucai Xiao's avatar
Shucai Xiao committed
589
        auto ht      = prog.insert_instruction(ins, actv_func2, xw_hr_h);
590

591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
        last_output          = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);

        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
            }
            else
            {
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
            }
        }
    }

    return {hidden_states, last_output};
}

// Work out the activation functions to use for a "gru" instruction.
//
// Before the gru operator is rewritten a bidirectional operator needs four
// functions and any other direction needs at least two, even if the user
// did not specify any. Missing entries are filled in as follows:
//   bidirectional: none  -> sigmoid, tanh, sigmoid, tanh
//                  one   -> that function four times
//                  two   -> the pair repeated
//                  three -> the three given, then the first one again
//                  more  -> returned unchanged
//   otherwise:     none  -> sigmoid, tanh
//                  one   -> that function twice
//                  more  -> returned unchanged
std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
{
    const auto gru_op = any_cast<op::gru>(ins->get_operator());
    const auto& given = gru_op.actv_funcs;

    if(gru_op.direction == op::rnn_direction::bidirectional)
    {
        switch(given.size())
        {
        case 0: return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
        case 1: return {given.at(0), given.at(0), given.at(0), given.at(0)};
        case 2: return {given.at(0), given.at(1), given.at(0), given.at(1)};
        case 3: return {given.at(0), given.at(1), given.at(2), given.at(0)};
        default: return given;
        }
    }

    switch(given.size())
    {
    case 0: return {op::sigmoid{}, op::tanh{}};
    case 1: return {given.at(0), given.at(0)};
    default: return given;
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
660
661
662
663
664
// for lstm operators
void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
    assert(ins->name() == "lstm");
    auto args = ins->inputs();
Shucai Xiao's avatar
Shucai Xiao committed
665

Shucai Xiao's avatar
Shucai Xiao committed
666
    shape seq_shape         = args[0]->get_shape();
667
    std::size_t hidden_size = args[2]->get_shape().lens()[2];
Shucai Xiao's avatar
Shucai Xiao committed
668
669
    std::size_t batch_size  = seq_shape.lens()[1];
    shape::type_t type      = seq_shape.type();
Shucai Xiao's avatar
Shucai Xiao committed
670
    migraphx::shape ihc_shape{type, {1, batch_size, hidden_size}};
Shucai Xiao's avatar
Shucai Xiao committed
671
    std::vector<float> ihc_data(ihc_shape.elements(), 0.0);
Shucai Xiao's avatar
Shucai Xiao committed
672
673
674

    migraphx::shape pph_shape{type, {1, 3 * hidden_size}};

Shucai Xiao's avatar
Shucai Xiao committed
675
676
    auto actv_funcs         = lstm_actv_funcs(ins);
    auto lstm_op            = any_cast<op::lstm>(ins->get_operator());
Shucai Xiao's avatar
Shucai Xiao committed
677
    op::rnn_direction dirct = lstm_op.direction;
Shucai Xiao's avatar
Shucai Xiao committed
678

Shucai Xiao's avatar
Shucai Xiao committed
679
680
681
682
683
684
685
686
687
688
    // process sequence length
    instruction_ref seq_lens = prog.end();
    if((args.size() >= 5) && args[4]->name() != "undefined")
    {
        seq_lens = args[4];
    }

    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

    instruction_ref last_hs_output{};
Shucai Xiao's avatar
Shucai Xiao committed
689
    instruction_ref last_cell_output{};
Shucai Xiao's avatar
Shucai Xiao committed
690
691
    instruction_ref hidden_state{};
    instruction_ref cell_outputs{};
Shucai Xiao's avatar
Shucai Xiao committed
692
    if(dirct == op::rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
693
694
695
696
697
698
699
700
701
702
703
704
705
    {
        // input weight matrix
        // input weight matrix
        auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
        auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

        // hidden state weight matrix
        auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
        auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

        // process bias
        instruction_ref bias_forward = prog.end();
        instruction_ref bias_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
706
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
        {
            bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
            bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
        }

        // process intial hidden state, it is the 6th argument
        instruction_ref ih_forward{};
        instruction_ref ih_reverse{};
        if(args.size() >= 6 && args[5]->name() != "undefined")
        {
            ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
            ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
        }
        else
        {
            ih_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ih_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process initial cell value
        instruction_ref ic_forward{};
        instruction_ref ic_reverse{};
Shucai Xiao's avatar
Shucai Xiao committed
729
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
730
731
732
733
734
735
736
737
738
739
740
        {
            ic_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[6]);
            ic_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[6]);
        }
        else
        {
            ic_forward = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
            ic_reverse = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
741
742
        instruction_ref pph_forward = prog.end();
        instruction_ref pph_reverse = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
743
        if(args.size() == 8 && args[7]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
744
745
746
747
        {
            pph_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[7]);
            pph_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[7]);
        }
Shucai Xiao's avatar
Shucai Xiao committed
748

Shucai Xiao's avatar
Shucai Xiao committed
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
        auto ret_forward = lstm_cell(true,
                                     prog,
                                     ins,
                                     {args[0],
                                      w_forward,
                                      r_forward,
                                      bias_forward,
                                      seq_lens,
                                      ih_forward,
                                      ic_forward,
                                      pph_forward},
                                     actv_funcs.at(0),
                                     actv_funcs.at(1),
                                     actv_funcs.at(2));

        if(variable_seq_len)
        {
            args[0] =
                prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
        }
        auto ret_reverse = lstm_cell(false,
                                     prog,
                                     ins,
                                     {args[0],
                                      w_reverse,
                                      r_reverse,
                                      bias_reverse,
                                      seq_lens,
                                      ih_reverse,
                                      ic_reverse,
                                      pph_reverse},
                                     actv_funcs.at(3),
                                     actv_funcs.at(4),
                                     actv_funcs.at(5));

        auto concat_hs_output =
Shucai Xiao's avatar
Shucai Xiao committed
785
            prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
786
787
788
789
        auto concat_cell_output =
            prog.insert_instruction(ins, op::concat{1}, ret_forward[3], ret_reverse[3]);
        last_hs_output   = prog.insert_instruction(ins, op::squeeze{{0}}, concat_hs_output);
        last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_cell_output);
Shucai Xiao's avatar
Shucai Xiao committed
790
791

        // the following logic is to ensure the last instruction is a concat
Shucai Xiao's avatar
Shucai Xiao committed
792
        if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
793
        {
Shucai Xiao's avatar
Shucai Xiao committed
794
            cell_outputs = concat_cell_output;
Shucai Xiao's avatar
Shucai Xiao committed
795
796
797
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
798
            ret_forward[1] =
Shucai Xiao's avatar
Shucai Xiao committed
799
                prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
Shucai Xiao's avatar
Shucai Xiao committed
800
            ret_reverse[1] =
Shucai Xiao's avatar
Shucai Xiao committed
801
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
Shucai Xiao's avatar
Shucai Xiao committed
802
803
804
805
806
807
808

            ret_forward[3] =
                prog.insert_instruction(ins, op::concat{0}, ret_forward[2], ret_forward[3]);
            ret_reverse[3] =
                prog.insert_instruction(ins, op::concat{0}, ret_reverse[3], ret_reverse[2]);
            cell_outputs =
                prog.insert_instruction(ins, op::concat{1}, ret_forward[3], ret_reverse[3]);
Shucai Xiao's avatar
Shucai Xiao committed
809
        }
Shucai Xiao's avatar
Shucai Xiao committed
810
811
812

        hidden_state =
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[1], ret_reverse[1]});
Shucai Xiao's avatar
Shucai Xiao committed
813
814
815
    }
    else
    {
Shucai Xiao's avatar
Shucai Xiao committed
816
        bool is_forward = (dirct == op::rnn_direction::forward);
Shucai Xiao's avatar
Shucai Xiao committed
817
818
819
820
821
822
        // weight matrices
        auto w = args[1];
        auto r = args[2];

        // bias
        instruction_ref bias = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
823
        if(args.size() >= 4 && args[3]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
824
825
826
827
828
829
        {
            bias = args[3];
        }

        // initial hidden state
        instruction_ref ih{};
Shucai Xiao's avatar
Shucai Xiao committed
830
        if(args.size() >= 6 && args[5]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
831
832
833
834
835
836
837
838
839
840
        {
            ih = args[5];
        }
        else
        {
            ih = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // initial cell value
        instruction_ref ic{};
Shucai Xiao's avatar
Shucai Xiao committed
841
        if(args.size() >= 7 && args[6]->name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
842
843
844
845
846
847
848
849
850
        {
            ic = args[6];
        }
        else
        {
            ic = prog.add_literal(migraphx::literal{ihc_shape, ihc_data});
        }

        // process weight of the peephole
Shucai Xiao's avatar
Shucai Xiao committed
851
        instruction_ref pph = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
852
853
854
855
        if(args.size() == 8 && args[7]->name() != "undefined")
        {
            pph = args[7];
        }
Shucai Xiao's avatar
Shucai Xiao committed
856

Shucai Xiao's avatar
Shucai Xiao committed
857
858
859
860
861
        if(!is_forward and variable_seq_len)
        {
            args[0] =
                prog.insert_instruction(ins, op::rnn_var_sl_shift_sequence{}, args[0], seq_lens);
        }
Shucai Xiao's avatar
Shucai Xiao committed
862
        auto ret = lstm_cell(is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
863
864
                             prog,
                             ins,
Shucai Xiao's avatar
Shucai Xiao committed
865
                             {args[0], w, r, bias, seq_lens, ih, ic, pph},
Shucai Xiao's avatar
Shucai Xiao committed
866
867
868
869
                             actv_funcs.at(0),
                             actv_funcs.at(1),
                             actv_funcs.at(2));

Shucai Xiao's avatar
Shucai Xiao committed
870
871
872
        last_hs_output   = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
        last_cell_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[3]);

Shucai Xiao's avatar
Shucai Xiao committed
873
        if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
874
        {
Shucai Xiao's avatar
Shucai Xiao committed
875
876
            cell_outputs = ret[3];
            hidden_state = prog.replace_instruction(ins, op::concat{0}, ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
877
878
879
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
880
881
882
883
884
            auto concat_cell_arg0 = is_forward ? ret[2] : ret[3];
            auto concat_cell_arg1 = is_forward ? ret[3] : ret[2];
            cell_outputs =
                prog.insert_instruction(ins, op::concat{0}, concat_cell_arg0, concat_cell_arg1);

Shucai Xiao's avatar
Shucai Xiao committed
885
886
            auto concat_arg0 = is_forward ? ret[0] : ret[1];
            auto concat_arg1 = is_forward ? ret[1] : ret[0];
Shucai Xiao's avatar
Shucai Xiao committed
887
            hidden_state = prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
888
889
890
        }
    }

Shucai Xiao's avatar
Shucai Xiao committed
891
892
    ins = replace_last_hs_output(prog, hidden_state, seq_lens, last_hs_output, dirct);
    replace_last_cell_output(prog, ins, seq_lens, cell_outputs, last_cell_output, dirct);
Shucai Xiao's avatar
Shucai Xiao committed
893
894
895
}

std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
896
897
898
899
900
901
                                                    program& prog,
                                                    instruction_ref ins,
                                                    std::vector<instruction_ref> inputs,
                                                    const operation& actv_func1,
                                                    const operation& actv_func2,
                                                    const operation& actv_func3) const
Shucai Xiao's avatar
Shucai Xiao committed
902
{
Shucai Xiao's avatar
Shucai Xiao committed
903
    // must have 7 args in the input vector
Shucai Xiao's avatar
Shucai Xiao committed
904
905
906
907
908
909
910
911
912
    assert(inputs.size() == 8);
    auto seq      = inputs.at(0);
    auto w        = inputs.at(1);
    auto r        = inputs.at(2);
    auto bias     = inputs.at(3);
    auto seq_lens = inputs.at(4);
    auto ih       = inputs.at(5);
    auto ic       = inputs.at(6);
    auto pph      = inputs.at(7);
Shucai Xiao's avatar
Shucai Xiao committed
913

Shucai Xiao's avatar
Shucai Xiao committed
914
    instruction_ref hidden_states = prog.end();
Shucai Xiao's avatar
Shucai Xiao committed
915
916
917
    instruction_ref cell_outputs  = prog.end();

    instruction_ref last_hs_output{};
Shucai Xiao's avatar
Shucai Xiao committed
918
919
920
    instruction_ref last_cell_output{};

    migraphx::shape seq_shape = seq->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
921
    migraphx::shape r_shape   = r->get_shape();
Shucai Xiao's avatar
Shucai Xiao committed
922
    long max_seq_len          = static_cast<long>(seq_shape.lens()[0]);
Shucai Xiao's avatar
Shucai Xiao committed
923
    long hs                   = static_cast<long>(r_shape.lens()[2]);
Shucai Xiao's avatar
Shucai Xiao committed
924
    auto bs                   = ih->get_shape().lens()[1];
Shucai Xiao's avatar
Shucai Xiao committed
925
926

    std::vector<int64_t> perm{1, 0};
927
    // w matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
928
    auto sw  = prog.insert_instruction(ins, op::squeeze{{0}}, w);
929
    auto tsw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
930

931
    // r matrix, squeeze and transpose
Shucai Xiao's avatar
Shucai Xiao committed
932
    auto sr  = prog.insert_instruction(ins, op::squeeze{{0}}, r);
933
    auto tsr = prog.insert_instruction(ins, op::transpose{perm}, sr);
Shucai Xiao's avatar
Shucai Xiao committed
934

Shucai Xiao's avatar
Shucai Xiao committed
935
936
937
938
    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // initial cell state
Shucai Xiao's avatar
Shucai Xiao committed
939
    auto sic     = prog.insert_instruction(ins, op::squeeze{{0}}, ic);
940
    auto ic_lens = sic->get_shape().lens();
Shucai Xiao's avatar
Shucai Xiao committed
941

942
    // bias
943
    instruction_ref wrb{};
Shucai Xiao's avatar
Shucai Xiao committed
944
    if(bias != prog.end())
945
    {
946

Shucai Xiao's avatar
Shucai Xiao committed
947
948
949
        auto sbias  = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto ub_wb  = prog.insert_instruction(ins, op::slice{{0}, {0}, {4 * hs}}, sbias);
        auto ub_rb  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {8 * hs}}, sbias);
950
        auto ub_wrb = prog.insert_instruction(ins, op::add{}, ub_wb, ub_rb);
951

952
953
        wrb = prog.insert_instruction(
            ins, op::broadcast{1, {bs, 4 * static_cast<size_t>(hs)}}, ub_wrb);
954
955
    }

Shucai Xiao's avatar
Shucai Xiao committed
956
    // peep hole
Shucai Xiao's avatar
Shucai Xiao committed
957
958
959
    instruction_ref pphi_brcst{};
    instruction_ref ppho_brcst{};
    instruction_ref pphf_brcst{};
Shucai Xiao's avatar
Shucai Xiao committed
960
    if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
961
    {
Shucai Xiao's avatar
Shucai Xiao committed
962
963
        auto spph  = prog.insert_instruction(ins, op::squeeze{{0}}, pph);
        auto pphi  = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, spph);
964
        pphi_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphi);
Shucai Xiao's avatar
Shucai Xiao committed
965

Shucai Xiao's avatar
Shucai Xiao committed
966
        auto ppho  = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, spph);
967
        ppho_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, ppho);
Shucai Xiao's avatar
Shucai Xiao committed
968

Shucai Xiao's avatar
Shucai Xiao committed
969
        auto pphf  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, spph);
970
        pphf_brcst = prog.insert_instruction(ins, op::broadcast{1, ic_lens}, pphf);
Shucai Xiao's avatar
Shucai Xiao committed
971
    }
Shucai Xiao's avatar
Shucai Xiao committed
972

Shucai Xiao's avatar
Shucai Xiao committed
973
    long seq_len = static_cast<long>(get_seq_len(prog, seq, seq_lens));
Shucai Xiao's avatar
Shucai Xiao committed
974
    for(long i = 0; i < seq_len; ++i)
Shucai Xiao's avatar
Shucai Xiao committed
975
976
977
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
Shucai Xiao's avatar
Shucai Xiao committed
978
979
        auto cont_xt = prog.insert_instruction(ins, op::contiguous{}, xt);
        xt           = prog.insert_instruction(ins, op::squeeze{{0}}, cont_xt);
Shucai Xiao's avatar
Shucai Xiao committed
980

981
982
        auto xt_tsw  = prog.insert_instruction(ins, op::dot{}, xt, tsw);
        auto sih_tsr = prog.insert_instruction(ins, op::dot{}, sih, tsr);
Shucai Xiao's avatar
Shucai Xiao committed
983
        auto xt_sih  = prog.insert_instruction(ins, op::add{}, xt_tsw, sih_tsr);
Shucai Xiao's avatar
Shucai Xiao committed
984
        if(bias != prog.end())
985
        {
Shucai Xiao's avatar
Shucai Xiao committed
986
            xt_sih = prog.insert_instruction(ins, op::add{}, xt_sih, wrb);
987
        }
Shucai Xiao's avatar
Shucai Xiao committed
988

989
990
        auto it_before_actv = prog.insert_instruction(ins, op::slice{{1}, {0}, {hs}}, xt_sih);
        auto ot_before_actv = prog.insert_instruction(ins, op::slice{{1}, {hs}, {2 * hs}}, xt_sih);
Shucai Xiao's avatar
Shucai Xiao committed
991
992
993
994
        auto ft_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {2 * hs}, {3 * hs}}, xt_sih);
        auto ct_before_actv =
            prog.insert_instruction(ins, op::slice{{1}, {3 * hs}, {4 * hs}}, xt_sih);
995

Shucai Xiao's avatar
Shucai Xiao committed
996
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
997
        {
Shucai Xiao's avatar
Shucai Xiao committed
998
999
1000
            auto pphi_ct   = prog.insert_instruction(ins, op::mul{}, pphi_brcst, sic);
            it_before_actv = prog.insert_instruction(ins, op::add{}, it_before_actv, pphi_ct);

Shucai Xiao's avatar
Shucai Xiao committed
1001
1002
            auto pphf_ct   = prog.insert_instruction(ins, op::mul{}, pphf_brcst, sic);
            ft_before_actv = prog.insert_instruction(ins, op::add{}, ft_before_actv, pphf_ct);
Shucai Xiao's avatar
Shucai Xiao committed
1003
        }
Shucai Xiao's avatar
Shucai Xiao committed
1004
        auto it = prog.insert_instruction(ins, actv_func1, it_before_actv);
1005
        auto ft = prog.insert_instruction(ins, actv_func1, ft_before_actv);
1006
        auto ct = prog.insert_instruction(ins, actv_func2, ct_before_actv);
1007
1008

        // equation Ct = ft (.) Ct-1 + it (.) ct
Shucai Xiao's avatar
Shucai Xiao committed
1009
1010
1011
        auto ft_cell = prog.insert_instruction(ins, op::mul{}, ft, sic);
        auto it_ct   = prog.insert_instruction(ins, op::mul{}, it, ct);
        auto cellt   = prog.insert_instruction(ins, op::add{}, ft_cell, it_ct);
1012

Shucai Xiao's avatar
Shucai Xiao committed
1013
        if(pph != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
1014
        {
Shucai Xiao's avatar
Shucai Xiao committed
1015
1016
            auto ppho_cellt = prog.insert_instruction(ins, op::mul{}, ppho_brcst, cellt);
            ot_before_actv  = prog.insert_instruction(ins, op::add{}, ot_before_actv, ppho_cellt);
Shucai Xiao's avatar
Shucai Xiao committed
1017
        }
1018
1019
1020
1021
        auto ot = prog.insert_instruction(ins, actv_func1, ot_before_actv);

        // Ht = ot (.) h(Ct)
        auto h_cellt = prog.insert_instruction(ins, actv_func3, cellt);
Shucai Xiao's avatar
Shucai Xiao committed
1022
        auto ht      = prog.insert_instruction(ins, op::mul{}, ot, h_cellt);
1023
1024
1025
1026

        sic = cellt;
        sih = ht;

Shucai Xiao's avatar
Shucai Xiao committed
1027
1028
        last_hs_output   = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
        last_cell_output = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, cellt);
1029

Shucai Xiao's avatar
Shucai Xiao committed
1030
        if(i < seq_len - 1)
1031
        {
Shucai Xiao's avatar
Shucai Xiao committed
1032
            if(i == 0)
1033
            {
Shucai Xiao's avatar
Shucai Xiao committed
1034
1035
                hidden_states = last_hs_output;
                cell_outputs  = last_cell_output;
1036
1037
1038
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
1039
1040
                auto concat_hs_arg0 = is_forward ? hidden_states : last_hs_output;
                auto concat_hs_arg1 = is_forward ? last_hs_output : hidden_states;
Shucai Xiao's avatar
Shucai Xiao committed
1041
                hidden_states =
Shucai Xiao's avatar
Shucai Xiao committed
1042
1043
1044
1045
1046
1047
                    prog.insert_instruction(ins, op::concat{0}, concat_hs_arg0, concat_hs_arg1);

                auto concat_cell_arg0 = is_forward ? cell_outputs : last_cell_output;
                auto concat_cell_arg1 = is_forward ? last_cell_output : cell_outputs;
                cell_outputs =
                    prog.insert_instruction(ins, op::concat{0}, concat_cell_arg0, concat_cell_arg1);
1048
1049
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1050
    }
1051

Shucai Xiao's avatar
Shucai Xiao committed
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
    // condition of all sequence are of the same length and
    // less than max_seq_len, we need to append the hs outputs
    // In this case, the cell_output is not used at all, so
    // no need to extand it to the avariable length
    if(seq_len < max_seq_len)
    {
        auto s        = last_hs_output->get_shape();
        auto pad_lens = s.lens();
        pad_lens[0]   = static_cast<std::size_t>(max_seq_len - seq_len);
        shape pad_s{s.type(), pad_lens};
        std::vector<float> data(pad_s.elements(), 0.0f);
        auto pl       = prog.add_literal(pad_s, data.begin(), data.end());
        hidden_states = prog.insert_instruction(ins, op::concat{0}, hidden_states, pl);
    }
1066

Shucai Xiao's avatar
Shucai Xiao committed
1067
    return {hidden_states, last_hs_output, cell_outputs, last_cell_output};
Shucai Xiao's avatar
Shucai Xiao committed
1068
1069
1070
1071
1072
1073
1074
1075
1076
}

// Build the complete list of activation functions for an lstm instruction.
//
// Before rewriting the lstm operator we need a full set of activation
// functions even if the user specified fewer (or none):
//   - bidirectional: 6 functions are returned. With none given the defaults
//     are {sigmoid, tanh, tanh, sigmoid, tanh, tanh}; with fewer than 6 given
//     the supplied ones are repeated as listed in the cases below.
//   - otherwise: 3 functions are returned. With none given the defaults are
//     {sigmoid, tanh, tanh}; with 1 or 2 given the supplied ones are repeated.
// When enough functions are supplied the operator's list is returned as is.
std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
{
    auto lstm_op = any_cast<op::lstm>(ins->get_operator());

    const auto& actv_funcs     = lstm_op.actv_funcs;
    std::size_t num_actv_funcs = actv_funcs.size();
    if(lstm_op.direction == op::rnn_direction::bidirectional)
    {
        switch(num_actv_funcs)
        {
        case 0:
            return {op::sigmoid{}, op::tanh{}, op::tanh{}, op::sigmoid{}, op::tanh{}, op::tanh{}};

        case 1:
            return {actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0),
                    actv_funcs.at(0)};

        case 2:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(1)};

        case 3:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2)};

        case 4:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(3),
                    actv_funcs.at(3)};

        case 5:
            return {actv_funcs.at(0),
                    actv_funcs.at(1),
                    actv_funcs.at(2),
                    actv_funcs.at(3),
                    actv_funcs.at(4),
                    actv_funcs.at(4)};

        default: return actv_funcs;
        }
    }
    else
    {
        switch(num_actv_funcs)
        {
        case 0: return {op::sigmoid{}, op::tanh{}, op::tanh{}};

        case 1: return {actv_funcs.at(0), actv_funcs.at(0), actv_funcs.at(0)};

        case 2: return {actv_funcs.at(0), actv_funcs.at(1), actv_funcs.at(1)};

        default: return actv_funcs;
        }
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
// Tell whether the sequence-length input describes sequences of differing
// lengths.
//   - seq_lens == prog.end() (no such input): false
//   - seq_lens cannot be evaluated here:      true
//   - otherwise: true only if some entry differs from the first entry
bool rewrite_rnn::is_variable_seq_lens(const program& prog, instruction_ref seq_lens) const
{
    if(seq_lens == prog.end())
        return false;

    if(!seq_lens->can_eval())
        return true;

    auto lens_arg = seq_lens->eval();
    std::vector<int64_t> lens;
    lens_arg.visit([&](auto view) { lens.assign(view.begin(), view.end()); });

    // an empty list has nothing to disagree with, so it counts as fixed length
    const int64_t first = lens.empty() ? 0 : lens[0];
    const bool uniform =
        std::all_of(lens.begin(), lens.end(), [&](auto len) { return len == first; });
    return !uniform;
}

// Number of sequence steps to process.
// Defaults to the first dimension of input. When a sequence-length input is
// present and is not variable (see is_variable_seq_lens), its first entry is
// used instead, unless it holds no entries at all.
std::size_t
rewrite_rnn::get_seq_len(const program& prog, instruction_ref input, instruction_ref seq_lens) const
{
    const bool variable = is_variable_seq_lens(prog, seq_lens);
    std::size_t steps   = input->get_shape().lens()[0];

    if(variable or seq_lens == prog.end())
        return steps;

    auto lens_arg = seq_lens->eval();
    std::vector<std::size_t> lens;
    lens_arg.visit([&](auto view) { lens.assign(view.begin(), view.end()); });
    if(!lens.empty())
    {
        steps = lens[0];
    }

    return steps;
}

// Patch every "rnn_last_hs_output" consumer of the rewritten hidden-state
// instruction.
//
// ins             instruction producing the concatenated hidden states
// seq_lens        sequence-length input, or prog.end() when absent
// last_hs_output  instruction holding the last hidden state; only used when
//                 the sequence lengths are not variable
// dirct           direction forwarded to the variable-length operators
//
// Returns the instruction that now stands for the hidden states: the newly
// inserted rnn_var_sl_shift_output when sequence lengths are variable,
// otherwise ins itself.
instruction_ref rewrite_rnn::replace_last_hs_output(program& prog,
                                                    instruction_ref ins,
                                                    instruction_ref seq_lens,
                                                    instruction_ref last_hs_output,
                                                    op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);
    instruction_ref result_ins{};
    if(variable_seq_len)
    {
        // insert the shift right after ins, then redirect ins to it
        result_ins = prog.insert_instruction(
            std::next(ins), op::rnn_var_sl_shift_output{"hidden_states", dirct}, ins, seq_lens);
        prog.replace_instruction(ins, result_ins);

        // correct the direction used for the operator
        // NOTE(review): the loop replaces instructions while walking
        // result_ins->outputs(); presumably the iterator stays usable across
        // replace_instruction - confirm against program/instruction.
        auto last_hs_output_it = result_ins->outputs().begin();
        while(last_hs_output_it != result_ins->outputs().end())
        {
            last_hs_output_it =
                std::find_if(last_hs_output_it, result_ins->outputs().end(), [](auto i) {
                    return i->name() == "rnn_last_hs_output";
                });

            if(last_hs_output_it != result_ins->outputs().end())
            {
                // swap the placeholder for rnn_var_sl_last_output, keeping its
                // first input and adding seq_lens
                auto inputs = (*last_hs_output_it)->inputs();
                prog.replace_instruction(*last_hs_output_it,
                                         op::rnn_var_sl_last_output{dirct},
                                         inputs.front(),
                                         seq_lens);
                last_hs_output_it++;
            }
        }
    }
    else
    {
        // fixed length: every placeholder is replaced directly by last_hs_output;
        // the loop handles more than one placeholder
        auto last_hs_output_it = ins->outputs().begin();
        while(last_hs_output_it != ins->outputs().end())
        {
            last_hs_output_it = std::find_if(last_hs_output_it, ins->outputs().end(), [](auto i) {
                return i->name() == "rnn_last_hs_output";
            });

            if(last_hs_output_it != ins->outputs().end())
            {
                prog.replace_instruction(*last_hs_output_it, last_hs_output);
                last_hs_output_it++;
            }
        }
        result_ins = ins;
    }

    return result_ins;
}

// Rewire every "rnn_last_cell_output" consumer of `ins` to the instruction
// that produces the last cell output.
//
// prog             - program being rewritten
// ins              - instruction whose outputs are searched for
//                    "rnn_last_cell_output" consumers
// seq_lens         - sequence-length input; passed to is_variable_seq_lens()
//                    and used as an argument of the variable-length operators
// cell_outputs     - cell outputs fed to the variable-length operators
// last_cell_output - replacement used when sequence lengths are not variable
// dirct            - direction forwarded to the variable-length operators
//
// When is_variable_seq_lens() is true, `cell_outputs` is first passed through
// an op::rnn_var_sl_shift_output{"cell_outputs", dirct} instruction inserted
// right after `ins`, and each consumer is rewritten in place to
// op::rnn_var_sl_last_output{dirct} reading that shifted result and
// `seq_lens`. Otherwise each consumer is replaced by `last_cell_output`.
// Nothing is inserted or replaced when `ins` has no such consumer.
void rewrite_rnn::replace_last_cell_output(program& prog,
                                           instruction_ref ins,
                                           instruction_ref seq_lens,
                                           instruction_ref cell_outputs,
                                           instruction_ref last_cell_output,
                                           op::rnn_direction dirct) const
{
    bool variable_seq_len = is_variable_seq_lens(prog, seq_lens);

    // Snapshot the matching consumers before rewriting any of them, so the
    // loops below never hold an iterator into ins->outputs() across a call to
    // replace_instruction().
    // NOTE(review): replace_instruction() presumably detaches a rewritten
    // consumer from `ins`, which would edit ins->outputs() while it is being
    // walked - confirm against program::replace_instruction.
    // The snapshot also handles the case of multiple rnn_last_cell_output
    // operators.
    std::vector<instruction_ref> last_cell_ops;
    for(auto out : ins->outputs())
    {
        if(out->name() == "rnn_last_cell_output")
        {
            last_cell_ops.push_back(out);
        }
    }

    if(last_cell_ops.empty())
    {
        return;
    }

    if(variable_seq_len)
    {
        // One shared shift instruction feeds all rewritten consumers.
        auto shifted_cell_outputs =
            prog.insert_instruction(std::next(ins),
                                    op::rnn_var_sl_shift_output{"cell_outputs", dirct},
                                    cell_outputs,
                                    seq_lens);

        for(auto last_cell_op : last_cell_ops)
        {
            prog.replace_instruction(last_cell_op,
                                     op::rnn_var_sl_last_output{dirct},
                                     shifted_cell_outputs,
                                     seq_lens);
        }
    }
    else
    {
        // replace the rnn_last_cell_output with the last_cell_output
        for(auto last_cell_op : last_cell_ops)
        {
            prog.replace_instruction(last_cell_op, last_cell_output);
        }
    }
}

namespace op {
Shucai Xiao's avatar
Shucai Xiao committed
1311
std::ostream& operator<<(std::ostream& os, rnn_direction v)
1312
{
Shucai Xiao's avatar
Shucai Xiao committed
1313
1314
    std::vector<std::string> rnn_direction_str = {"forward", "reverse", "bidirectional"};
    os << rnn_direction_str[static_cast<std::underlying_type<rnn_direction>::type>(v)];
1315
1316
    return os;
}
Shucai Xiao's avatar
Shucai Xiao committed
1317
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx