#include <migraphx/rewrite_gru.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Lower every "gru" instruction (and its companion "gru_last_output") into
// primitive operators (slice/dot/add/mul/activation/concat/...) so backends
// without a native GRU kernel can execute the program.
void rewrite_gru::apply(program& prog) const
{
    // Final hidden state of the most recently lowered gru; a following
    // gru_last_output instruction is rewritten to reuse it directly.
    instruction_ref last_output = prog.end();

    for(auto ins : iterator_for(prog))
    {
        if(ins->name() == "gru")
        {
            // could be 3 to 5 inputs (though the onnx gru has 6 inputs,
            // the 5th one (sequence_lens) is undefined and ignored by
            // protobuf, so we need to process up to 5 inputs
            auto args = ins->inputs();

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[2]->get_shape().lens()[2];
            std::size_t batchs      = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();
            migraphx::shape ih_shape{type, {1, batchs, hidden_size}};
            // zero bytes: default (all-zero) initial hidden state literal
            std::vector<char> data(ih_shape.bytes(), 0);

            auto gru_op                        = any_cast<op::gru>(ins->get_operator());
            op::gru::gru_direction_t direction = gru_op.direction;
            if(direction == op::gru::bidirectional)
            {
                // w weight matrix: one slice per direction along axis 0
                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

                // r (recurrence) weight matrix
                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

                // optional bias; prog.end() is the "no bias" sentinel
                instruction_ref bias_forward = prog.end();
                instruction_ref bias_reverse = prog.end();
                if(args.size() >= 4)
                {
                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                }

                // initial hidden state: present either as the 6th input, or as
                // a rank-3 5th input (rank distinguishes it from sequence_lens)
                instruction_ref ih_forward;
                instruction_ref ih_reverse;
                if(args.size() == 6 ||
                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
                {
                    auto arg_ih = (args.size() == 6) ? args[5] : args[4];
                    ih_forward  = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
                    ih_reverse  = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
                }
                else
                {
                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret_forward = gru_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            gru_op.linear_before_reset,
                                            gru_op.actv_funcs.at(0),
                                            gru_op.actv_funcs.at(1));

                auto ret_reverse = gru_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            gru_op.linear_before_reset,
                                            gru_op.actv_funcs.at(2),
                                            gru_op.actv_funcs.at(3));

                // stack the two final hidden states along the direction axis
                last_output =
                    prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);

                // add the dimension of num_direction
                ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
                ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);

                // concat the forward and reverse output
                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
            }
            else
            {
                bool is_forward = (direction == op::gru::forward);
                // weight matrices
                auto w = args[1];
                auto r = args[2];

                // optional bias; prog.end() is the "no bias" sentinel
                instruction_ref bias = prog.end();
                if(args.size() >= 4)
                {
                    bias = args[3];
                }

                // initial hidden state (same argument-count / rank
                // disambiguation as above)
                instruction_ref ih;
                if(args.size() == 6 ||
                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
                {
                    ih = (args.size() == 6) ? args[5] : args[4];
                }
                else
                {
                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret = gru_cell(is_forward,
                                    prog,
                                    ins,
                                    args[0],
                                    w,
                                    r,
                                    bias,
                                    ih,
                                    gru_op.linear_before_reset,
                                    gru_op.actv_funcs.at(0),
                                    gru_op.actv_funcs.at(1));

                last_output = ret[1];

                // add the dimension of num_direction
                prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
            }
        }

        // rewrite the gru_last_output operator that comes right after the gru
        // operator. Intuitively, we could slice its input to get the last
        // output, but that value was already produced while lowering the gru,
        // so we just reuse it here
        if(ins->name() == "gru_last_output")
        {
            if(last_output != prog.end())
            {
                prog.replace_instruction(ins, op::identity{}, last_output);
                // consume the cached value so a stray gru_last_output without
                // a preceding gru is left untouched
                last_output = prog.end();
            }
        }
    }
}

// Expand one GRU direction into primitive instructions inserted before `ins`.
//
// Parameters:
//   is_forward          - true: iterate t = 0..seq_len-1; false: iterate in reverse
//   prog, ins           - program and insertion point (the original gru instruction)
//   input               - X, sliced per time step along axis 0
//   w, r, bias, ih      - per-direction weights, recurrence weights, optional
//                         bias (prog.end() means "no bias"), and initial hidden
//                         state; each carries a leading num_direction axis of 1
//                         that is squeezed away below
//   linear_before_reset - selects which ONNX candidate-gate formula is emitted
//   actv_func1          - gate activation (z and r gates)
//   actv_func2          - candidate activation (h gate)
//
// Returns a 2-element vector: [0] the full per-step hidden sequence
// (concatenated along axis 0), [1] the final hidden state (unsqueezed to a
// leading axis of 1). NOTE(review): both are uninitialized if seq_len == 0 —
// callers presumably never pass an empty sequence; verify upstream.
std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
                                                   instruction_ref w,
                                                   instruction_ref r,
                                                   instruction_ref bias,
                                                   instruction_ref ih,
                                                   int linear_before_reset,
                                                   operation& actv_func1,
                                                   operation& actv_func2) const
{
    instruction_ref hidden_out, last_out;
    long seq_len   = static_cast<long>(input->get_shape().lens()[0]);
    long hs        = static_cast<long>(r->get_shape().lens()[2]);

    // {batch, hidden} shape shared by all per-step gate tensors
    migraphx::shape s(input->get_shape().type(),
                      {input->get_shape().lens()[1], static_cast<std::size_t>(hs)});
    // literal of ones used to compute (1 - zt) in the state-update equation
    std::vector<int> data(s.elements(), 1);
    auto l1 = prog.add_literal(migraphx::literal{s, data});

    // weight matrix
    // W is packed as [Wz; Wr; Wh] along axis 0 (each hs rows); slice each
    // gate's weights out and transpose once, outside the time loop
    std::vector<int64_t> perm{1, 0};
    auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto wz  = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);

    auto wr  = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);

    auto wh  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);

    // R is packed the same way: [Rz; Rr; Rh]
    auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rz  = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);

    auto rr  = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);

    auto rh  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);

    // initial states; sih is reassigned each iteration to hold H(t-1)
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
    // B is packed as [Wbz Wbr Wbh Rbz Rbr Rbh] (6 * hs values). The W and R
    // parts of each gate's bias are pre-summed and broadcast to {batch, hs}
    // outside the loop; wbh/rbh are also kept broadcast separately for the
    // linear_before_reset == 1 formula.
    instruction_ref brcst_bz, brcst_br, brcst_wbh, brcst_rbh, brcst_bh;
    if(bias != prog.end())
    {
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
        brcst_wbh   = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);

        auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
        auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
        auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brcst_rbh   = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);

        auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
        brcst_bz   = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);

        auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
        brcst_br   = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);

        auto bh = prog.insert_instruction(ins, op::add{}, wbh, rbh);
        brcst_bh   = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
    }

    // seq_index walks forward or backward through time depending on direction
    long seq_index = is_forward ? 0 : seq_len - 1;
    for(long i = 0; i < seq_len; i++)
    {
        // Xt: slice out this time step and drop the leading axis
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);

        // update gate, equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
        auto xt_wz    = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
        auto ht_rz    = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
        if(bias != prog.end())
        {
            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
        }
        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);

        // reset gate, equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
        auto xt_wr    = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
        auto ht_rr    = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
        if(bias != prog.end())
        {
            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
        }
        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);

        // candidate hidden state; the two branches implement the two ONNX
        // linear_before_reset variants (reset applied before vs after the
        // recurrence matmul)
        instruction_ref xht_h;
        if(linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
            auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
            xht_h    = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
            if(bias != prog.end())
            {
                // brcst_bh = Wbh + Rbh, broadcast
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
            }
        }
        else
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            auto xt_wh   = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
            if(bias != prog.end())
            {
                // Rbh is inside the reset product here, so it is added first
                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
            }
            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
            xht_h    = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
            if(bias != prog.end())
            {
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
            }
        }
        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);

        // state update, equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one_minus_zt   = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
        auto zt_ht1 = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih         = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
        // last_out always holds the most recent Ht with a leading axis of 1
        last_out  = prog.insert_instruction(ins, op::unsqueeze{{0}}, sih);

        // accumulate the sequence output in time order: forward appends at
        // the back, reverse prepends at the front
        if(is_forward)
        {
            hidden_out = (seq_index == 0)
                             ? last_out
                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
        }
        else
        {
            hidden_out = (seq_index == seq_len - 1)
                             ? last_out
                             : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
        }
        seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
    }

    std::vector<instruction_ref> out_args;
    out_args.push_back(hidden_out);
    out_args.push_back(last_out);

    return out_args;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx