// rewrite_gru.cpp
#include <migraphx/rewrite_gru.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Rewrites every "gru" instruction in @p prog into primitive operations.
///
/// For each gru instruction, gru_cell is called once per direction to unroll
/// the recurrence. The gru instruction itself is then replaced in place by a
/// concat of the per-step hidden states. Every "gru_last_output" instruction
/// that consumes it is replaced by the hidden state of the final step.
void rewrite_gru::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        if(ins->name() != "gru")
            continue;

        const auto actv_funcs = compute_actv_funcs(ins);

        // The gru instruction has 3 to 6 inputs:
        //   args[0] input sequence, args[1] W, args[2] R,
        //   optionally args[3] bias, args[4] (not read here; sequence_lens in ONNX),
        //   and args[5] initial hidden state.
        // An optional input whose operator is "undefined" is treated as absent.
        auto args = ins->inputs();

        shape seq_shape         = args[0]->get_shape();
        std::size_t hidden_size = args[2]->get_shape().lens()[2];
        std::size_t batch_size  = seq_shape.lens()[1];
        shape::type_t type      = seq_shape.type();

        // All-zero initial hidden state for one direction, used when none is given.
        migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
        std::vector<float> data(ih_shape.elements(), 0.0);

        // Take a copy: ins is replaced in place further down, so the op's
        // fields must not depend on it afterwards.
        auto gru_op                        = any_cast<op::gru>(ins->get_operator());
        op::gru::gru_direction_t direction = gru_op.direction;
        instruction_ref last_output{};
        if(direction == op::gru::bidirectional)
        {
            // Along axis 0 of every per-direction input, index 0 holds the
            // forward parameters and index 1 the reverse ones.

            // w weight matrix
            auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
            auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

            // r weight matrix
            auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
            auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

            // bias; prog.end() tells gru_cell that there is no bias
            instruction_ref bias_forward = prog.end();
            instruction_ref bias_reverse = prog.end();
            if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
            {
                bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
            }

            // initial hidden state
            instruction_ref ih_forward{};
            instruction_ref ih_reverse{};
            if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
            {
                ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
                ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
            }
            else
            {
                ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
            }

            auto ret_forward =
                gru_cell(true,
                         prog,
                         ins,
                         {args[0], w_forward, r_forward, bias_forward, ih_forward},
                         gru_op.linear_before_reset,
                         actv_funcs.at(0),
                         actv_funcs.at(1));

            auto ret_reverse =
                gru_cell(false,
                         prog,
                         ins,
                         {args[0], w_reverse, r_reverse, bias_reverse, ih_reverse},
                         gru_op.linear_before_reset,
                         actv_funcs.at(2),
                         actv_funcs.at(3));

            auto concat_output =
                prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
            last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

            // The following logic ensures that the last instruction rewritten
            // from the gru operator is a concat.
            if(ret_forward[0] == prog.end())
            {
                // Single-step sequence: only the last outputs exist.
                prog.replace_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
            }
            else
            {
                // The forward last step comes after the earlier steps; the
                // reverse last step (sequence index 0) comes before them.
                ret_forward[0] =
                    prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
                ret_reverse[0] =
                    prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
            }
        }
        else
        {
            bool is_forward = (direction == op::gru::forward);

            // weight matrices
            auto w = args[1];
            auto r = args[2];

            // bias; prog.end() tells gru_cell that there is no bias
            instruction_ref bias = prog.end();
            if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
            {
                bias = args[3];
            }

            // initial hidden state
            instruction_ref ih{};
            if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
            {
                ih = args[5];
            }
            else
            {
                ih = prog.add_literal(migraphx::literal{ih_shape, data});
            }

            auto ret = gru_cell(is_forward,
                                prog,
                                ins,
                                {args[0], w, r, bias, ih},
                                gru_op.linear_before_reset,
                                actv_funcs.at(0),
                                actv_funcs.at(1));

            last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

            if(ret[0] == prog.end())
            {
                // Single-step sequence: only the last output exists.
                prog.replace_instruction(ins, op::concat{0}, ret[1]);
            }
            else
            {
                // Keep the hidden states in sequence order: the final step of a
                // reverse pass is sequence index 0, so it goes first.
                auto concat_arg0 = is_forward ? ret[0] : ret[1];
                auto concat_arg1 = is_forward ? ret[1] : ret[0];
                prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
            }
        }

        // Replace every gru_last_output that consumes this gru with last_output.
        // Iterate over a copy of the outputs: replace_instruction may rewire
        // instruction arguments, which can change ins->outputs() and would
        // invalidate iterators into the live vector.
        auto outputs = ins->outputs();
        for(auto output : outputs)
        {
            if(output->name() == "gru_last_output")
            {
                prog.replace_instruction(output, last_output);
            }
        }
    }
}

std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
176
177
                                                   program& prog,
                                                   instruction_ref ins,
178
                                                   std::vector<instruction_ref> inputs,
179
                                                   int linear_before_reset,
Shucai Xiao's avatar
Shucai Xiao committed
180
181
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
182
{
183
    assert(inputs.size() == 5);
Shucai Xiao's avatar
Shucai Xiao committed
184
185
186
    auto seq  = inputs.at(0);
    auto w    = inputs.at(1);
    auto r    = inputs.at(2);
187
    auto bias = inputs.at(3);
Shucai Xiao's avatar
Shucai Xiao committed
188
    auto ih   = inputs.at(4);
189

190
    instruction_ref hidden_states = prog.end(), last_output;
191
    long seq_len                  = static_cast<long>(seq->get_shape().lens()[0]);
Shucai Xiao's avatar
Shucai Xiao committed
192
    long hs                       = static_cast<long>(r->get_shape().lens()[2]);
193

194
195
    migraphx::shape s(seq->get_shape().type(),
                      {seq->get_shape().lens()[1], static_cast<std::size_t>(hs)});
196
197
    std::vector<int> data(s.elements(), 1);
    auto l1 = prog.add_literal(migraphx::literal{s, data});
198
199
200

    // weight matrix
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
201
202
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto wz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
203
204
    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);

Shucai Xiao's avatar
Shucai Xiao committed
205
    auto wr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
206
    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
Shucai Xiao's avatar
Shucai Xiao committed
207
208

    auto wh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
209
210
    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);

Shucai Xiao's avatar
Shucai Xiao committed
211
212
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
213
214
    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);

Shucai Xiao's avatar
Shucai Xiao committed
215
    auto rr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
216
217
    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);

Shucai Xiao's avatar
Shucai Xiao committed
218
    auto rh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
219
220
221
222
    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);

    // initial states
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
223
224

    // bias
Shucai Xiao's avatar
Shucai Xiao committed
225
226
227
228
229
    instruction_ref brcst_bz{};
    instruction_ref brcst_br{};
    instruction_ref brcst_wbh{};
    instruction_ref brcst_rbh{};
    instruction_ref brcst_bh{};
Shucai Xiao's avatar
Shucai Xiao committed
230
    if(bias != prog.end())
231
    {
232
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
        auto wbz   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto wbr   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto wbh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
        brcst_wbh  = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);

        auto rbz  = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
        auto rbr  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
        auto rbh  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);

        auto bz  = prog.insert_instruction(ins, op::add{}, wbz, rbz);
        brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);

        auto br  = prog.insert_instruction(ins, op::add{}, wbr, rbr);
        brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);

        auto bh  = prog.insert_instruction(ins, op::add{}, wbh, rbh);
        brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
251
252
253
254
    }

    for(long i = 0; i < seq_len; i++)
    {
255
        long seq_index = is_forward ? i : (seq_len - 1 - i);
256
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, seq);
257
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
258

259
        // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
Shucai Xiao's avatar
Shucai Xiao committed
260
261
        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
262
        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
Shucai Xiao's avatar
Shucai Xiao committed
263
        if(bias != prog.end())
264
        {
265
            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
266
        }
267
        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
268
269

        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
Shucai Xiao's avatar
Shucai Xiao committed
270
271
        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
272
        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
Shucai Xiao's avatar
Shucai Xiao committed
273
        if(bias != prog.end())
274
        {
275
            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
276
        }
277
        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
278

279
        instruction_ref xht_h;
Shucai Xiao's avatar
Shucai Xiao committed
280
        if(linear_before_reset == 0)
281
282
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
283
284
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
Shucai Xiao's avatar
Shucai Xiao committed
285
286
            auto rt_rh  = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
            xht_h       = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
Shucai Xiao's avatar
Shucai Xiao committed
287
            if(bias != prog.end())
288
            {
289
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
290
291
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
292
        else
293
294
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
Shucai Xiao's avatar
Shucai Xiao committed
295
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
296
            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
Shucai Xiao's avatar
Shucai Xiao committed
297
            if(bias != prog.end())
298
            {
299
                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
300
            }
301
            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
Shucai Xiao's avatar
Shucai Xiao committed
302
            xht_h      = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
Shucai Xiao's avatar
Shucai Xiao committed
303
            if(bias != prog.end())
304
            {
305
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
306
307
            }
        }
308
        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
309
310

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
Shucai Xiao's avatar
Shucai Xiao committed
311
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
312
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
Shucai Xiao's avatar
Shucai Xiao committed
313
314
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
315
        last_output          = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
316

Shucai Xiao's avatar
Shucai Xiao committed
317
        if(i < seq_len - 1)
318
        {
319
320
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
321
322
323
324
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
325
326
327
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
328
329
330
331
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
332
            }
333
334
335
        }
    }

336
    return {hidden_states, last_output};
337
338
}

Shucai Xiao's avatar
Shucai Xiao committed
339
340
341
342
343
344
345
346
347
348
349
350
/// Expands the activation functions of a gru instruction to the full set
/// consumed by apply().
///
/// Bidirectional GRUs need 4 functions ({f, g} forward, then {f, g} reverse);
/// other directions need 2. If the op specifies fewer, the missing ones are
/// filled in with the same scheme parse_gru uses:
///   bidirectional: 0 -> {sigmoid, tanh, sigmoid, tanh}, 1 -> {a, a, a, a},
///                  2 -> {a, b, a, b}, 3 -> {a, b, c, a}.
///   otherwise:     0 -> {sigmoid, tanh}, 1 -> {a, a}.
/// Any larger list is returned unchanged.
std::vector<operation> rewrite_gru::compute_actv_funcs(instruction_ref ins) const
{
    auto gru_op       = any_cast<op::gru>(ins->get_operator());
    const auto& funcs = gru_op.actv_funcs;

    if(gru_op.direction == op::gru::bidirectional)
    {
        switch(funcs.size())
        {
        case 0: return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
        case 1: return {funcs.at(0), funcs.at(0), funcs.at(0), funcs.at(0)};
        case 2: return {funcs.at(0), funcs.at(1), funcs.at(0), funcs.at(1)};
        case 3: return {funcs.at(0), funcs.at(1), funcs.at(2), funcs.at(0)};
        default: return funcs;
        }
    }

    switch(funcs.size())
    {
    case 0: return {op::sigmoid{}, op::tanh{}};
    case 1: return {funcs.at(0), funcs.at(0)};
    default: return funcs;
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx