"vscode:/vscode.git/clone" did not exist on "e06a9ddafcec68f82730d608be1b0521c9c05ae0"
rewrite_gru.cpp 16.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
#include <migraphx/rewrite_gru.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Rewrites every "gru" instruction of the program into primitive
// instructions (slice, dot, add, mul, ...) generated by gru_cell(), and
// replaces every "gru_last_output" instruction with the last-output
// instruction recorded for the gru it reads from.
//
// prog: the program to rewrite in place.
void rewrite_gru::apply(program& prog) const
{
    // Maps a rewritten gru instruction (after the rewrite it is the concat
    // that produces the hidden states) to the instruction producing its
    // last output, so a later gru_last_output can be replaced by it.
    std::unordered_map<instruction_ref, instruction_ref> map_last_output;

    for(auto ins : iterator_for(prog))
    {
        if(ins->name() == "gru")
        {
            const auto actv_funcs = compute_actv_funcs(ins);

            // A gru instruction carries 3 to 6 inputs. Input 3 (bias) and
            // input 5 (initial hidden state) are optional and may also be
            // an "undefined" placeholder. Input 4 is never read here
            // (presumably the ONNX sequence_lens input -- TODO confirm).
            auto args = ins->inputs();

            const bool has_bias =
                args.size() >= 4 && args[3]->get_operator().name() != "undefined";
            const bool has_init_state =
                args.size() == 6 && args[5]->get_operator().name() != "undefined";

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[2]->get_shape().lens()[2];
            std::size_t batch_size  = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();

            // Zero-filled buffer used as the default initial hidden state.
            migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
            std::vector<char> data(ih_shape.bytes(), 0);

            auto gru_op                        = any_cast<op::gru>(ins->get_operator());
            op::gru::gru_direction_t direction = gru_op.direction;
            if(direction == op::gru::bidirectional)
            {
                // Along axis 0, index 0 holds the forward direction and
                // index 1 the reverse direction.

                // w weight matrix
                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

                // r weight matrix
                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

                // bias; prog.end() tells gru_cell that there is no bias
                instruction_ref bias_forward = prog.end();
                instruction_ref bias_reverse = prog.end();
                if(has_bias)
                {
                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                }

                // initial hidden state; zeros when the input is absent
                instruction_ref ih_forward;
                instruction_ref ih_reverse;
                if(has_init_state)
                {
                    ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
                    ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
                }
                else
                {
                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                // Activation functions 0/1 drive the forward pass and
                // functions 2/3 the reverse pass.
                auto ret_forward = gru_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            gru_op.linear_before_reset,
                                            actv_funcs.at(0),
                                            actv_funcs.at(1));

                auto ret_reverse = gru_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            gru_op.linear_before_reset,
                                            actv_funcs.at(2),
                                            actv_funcs.at(3));

                auto concat_output =
                    prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

                // The following logic ensures that the last instruction
                // rewritten from the gru operator is a concat.
                instruction_ref hidden_state{};
                if(ret_forward[0] == prog.end())
                {
                    // gru_cell accumulated no earlier steps, so the last
                    // outputs are the only hidden states.
                    hidden_state = prog.replace_instruction(
                        ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
                }
                else
                {
                    // Append the last step to the forward states and prepend
                    // it to the reverse states before joining both directions.
                    ret_forward[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
                    ret_reverse[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
                    hidden_state = prog.replace_instruction(
                        ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
                }
                map_last_output[hidden_state] = last_output;
            }
            else
            {
                bool is_forward = (direction == op::gru::forward);

                // weight matrices
                auto w = args[1];
                auto r = args[2];

                // bias; prog.end() tells gru_cell that there is no bias
                instruction_ref bias = prog.end();
                if(has_bias)
                {
                    bias = args[3];
                }

                // initial hidden state; zeros when the input is absent
                instruction_ref ih;
                if(has_init_state)
                {
                    ih = args[5];
                }
                else
                {
                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret = gru_cell(is_forward,
                                    prog,
                                    ins,
                                    args[0],
                                    w,
                                    r,
                                    bias,
                                    ih,
                                    gru_op.linear_before_reset,
                                    actv_funcs.at(0),
                                    actv_funcs.at(1));

                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

                instruction_ref hidden_state{};
                if(ret[0] == prog.end())
                {
                    hidden_state = prog.replace_instruction(ins, op::concat{0}, ret[1]);
                }
                else
                {
                    // The last computed step goes to the end of the sequence
                    // for a forward pass and to the front for a reverse pass.
                    auto concat_arg0 = is_forward ? ret[0] : ret[1];
                    auto concat_arg1 = is_forward ? ret[1] : ret[0];
                    hidden_state =
                        prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
                }
                map_last_output[hidden_state] = last_output;
            }
        }

        // Rewrite the gru_last_output operator that comes right after the
        // gru operator. Intuitively, we could slice its input to get the
        // last output, but that value already exists from the gru rewrite
        // above, so it is used directly as the output here.
        if(ins->name() == "gru_last_output")
        {
            auto inputs = ins->inputs();
            assert(inputs.size() == 1);
            assert(map_last_output.count(inputs[0]) > 0);
            // at() rather than operator[]: when asserts are compiled out, a
            // missing entry throws instead of silently inserting a
            // default-constructed instruction_ref.
            prog.replace_instruction(ins, map_last_output.at(inputs[0]));
        }
    }
}

181
std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
182
183
184
185
186
187
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
                                                   instruction_ref w,
                                                   instruction_ref r,
                                                   instruction_ref bias,
188
                                                   instruction_ref ih,
189
                                                   int linear_before_reset,
Shucai Xiao's avatar
Shucai Xiao committed
190
191
                                                   const operation& actv_func1,
                                                   const operation& actv_func2) const
192
{
Shucai Xiao's avatar
Shucai Xiao committed
193
    assert(actv_funcs.size() == 2);
194
    instruction_ref hidden_states = prog.end(), last_output;
Shucai Xiao's avatar
Shucai Xiao committed
195
196
    long seq_len                  = static_cast<long>(input->get_shape().lens()[0]);
    long hs                       = static_cast<long>(r->get_shape().lens()[2]);
197

Shucai Xiao's avatar
Shucai Xiao committed
198
199
    migraphx::shape s(input->get_shape().type(),
                      {input->get_shape().lens()[1], static_cast<std::size_t>(hs)});
200
201
    std::vector<int> data(s.elements(), 1);
    auto l1 = prog.add_literal(migraphx::literal{s, data});
202
203
204

    // weight matrix
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
205
206
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto wz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
207
208
    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);

Shucai Xiao's avatar
Shucai Xiao committed
209
    auto wr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
210
    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
Shucai Xiao's avatar
Shucai Xiao committed
211
212

    auto wh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
213
214
    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);

Shucai Xiao's avatar
Shucai Xiao committed
215
216
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
217
218
    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);

Shucai Xiao's avatar
Shucai Xiao committed
219
    auto rr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
220
221
    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);

Shucai Xiao's avatar
Shucai Xiao committed
222
    auto rh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
223
224
225
226
    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);

    // initial states
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
227
228

    // bias
229
    instruction_ref brcst_bz, brcst_br, brcst_wbh, brcst_rbh, brcst_bh;
Shucai Xiao's avatar
Shucai Xiao committed
230
    if(bias != prog.end())
231
    {
232
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
        auto wbz   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto wbr   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto wbh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
        brcst_wbh  = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);

        auto rbz  = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
        auto rbr  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
        auto rbh  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);

        auto bz  = prog.insert_instruction(ins, op::add{}, wbz, rbz);
        brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);

        auto br  = prog.insert_instruction(ins, op::add{}, wbr, rbr);
        brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);

        auto bh  = prog.insert_instruction(ins, op::add{}, wbh, rbh);
        brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
251
252
253
254
    }

    for(long i = 0; i < seq_len; i++)
    {
255
        long seq_index = is_forward ? i : (seq_len - 1 - i);
256
257
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
258

259
        // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
Shucai Xiao's avatar
Shucai Xiao committed
260
261
        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
262
        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
Shucai Xiao's avatar
Shucai Xiao committed
263
        if(bias != prog.end())
264
        {
265
            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
266
        }
267
        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
268
269

        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
Shucai Xiao's avatar
Shucai Xiao committed
270
271
        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
272
        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
Shucai Xiao's avatar
Shucai Xiao committed
273
        if(bias != prog.end())
274
        {
275
            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
276
        }
277
        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
278

279
        instruction_ref xht_h;
Shucai Xiao's avatar
Shucai Xiao committed
280
        if(linear_before_reset == 0)
281
282
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
283
284
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
Shucai Xiao's avatar
Shucai Xiao committed
285
286
            auto rt_rh  = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
            xht_h       = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
Shucai Xiao's avatar
Shucai Xiao committed
287
            if(bias != prog.end())
288
            {
289
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
290
291
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
292
        else
293
294
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
Shucai Xiao's avatar
Shucai Xiao committed
295
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
296
            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
Shucai Xiao's avatar
Shucai Xiao committed
297
            if(bias != prog.end())
298
            {
299
                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
300
            }
301
            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
Shucai Xiao's avatar
Shucai Xiao committed
302
            xht_h      = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
Shucai Xiao's avatar
Shucai Xiao committed
303
            if(bias != prog.end())
304
            {
305
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
306
307
            }
        }
308
        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
309
310

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
Shucai Xiao's avatar
Shucai Xiao committed
311
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
312
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
Shucai Xiao's avatar
Shucai Xiao committed
313
314
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
315
        last_output          = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
316

Shucai Xiao's avatar
Shucai Xiao committed
317
        if(i < seq_len - 1)
318
        {
319
320
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
321
322
323
324
                hidden_states =
                    (seq_index == 0)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, hidden_states, last_output);
325
326
327
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
328
329
330
331
                hidden_states =
                    (seq_index == seq_len - 1)
                        ? last_output
                        : prog.insert_instruction(ins, op::concat{0}, last_output, hidden_states);
332
            }
333
334
335
        }
    }

336
    return {hidden_states, last_output};
337
338
}

Shucai Xiao's avatar
Shucai Xiao committed
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
// Expands the activation functions stored on the gru operator of `ins` to
// the full set the rewrite uses: 4 entries for a bidirectional gru and at
// least 2 otherwise. Missing entries default to sigmoid/tanh or are filled
// by repeating the given ones; a list that is already long enough is
// returned unchanged.
std::vector<operation> rewrite_gru::compute_actv_funcs(instruction_ref ins) const
{
    const auto gru_op        = any_cast<op::gru>(ins->get_operator());
    const auto& given        = gru_op.actv_funcs;
    const bool bidirectional = (gru_op.direction == op::gru::bidirectional);

    switch(given.size())
    {
    case 0:
        // nothing specified: fall back to the default pair per direction
        if(bidirectional)
        {
            return {op::sigmoid{}, op::tanh{}, op::sigmoid{}, op::tanh{}};
        }
        return {op::sigmoid{}, op::tanh{}};

    case 1:
        // a single function is used in every slot
        if(bidirectional)
        {
            return {given.at(0), given.at(0), given.at(0), given.at(0)};
        }
        return {given.at(0), given.at(0)};

    case 2:
        // one pair: reused for the second direction when bidirectional
        if(bidirectional)
        {
            return {given.at(0), given.at(1), given.at(0), given.at(1)};
        }
        return given;

    case 3:
        // NOTE(review): the fourth slot repeats function 0, not function 1;
        // confirm this matches the fill rule used by parse_gru.
        if(bidirectional)
        {
            return {given.at(0), given.at(1), given.at(2), given.at(0)};
        }
        return given;

    default: return given;
    }
}

380
381
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx