rewrite_gru.cpp 14.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
#include <migraphx/rewrite_gru.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_gru::apply(program& prog) const
{
    // Each rewritten "gru" instruction (whose result becomes the full hidden-state
    // sequence) -> the instruction that computes its last hidden state. A later
    // "gru_last_output" reading that gru is redirected through this map.
    std::unordered_map<instruction_ref, instruction_ref> map_last_output;

    for(auto ins : iterator_for(prog))
    {
        if(ins->name() == "gru")
        {
            // Inputs follow the ONNX GRU order, 3 to 6 of them:
            //   X, W, R, [B], [sequence_lens], [initial_h]
            // An optional input may be missing or be an "undefined" placeholder.
            // sequence_lens (args[4]) is not used by this rewrite.
            auto args        = ins->inputs();
            auto is_provided = [&](std::size_t idx) {
                return args.size() > idx && args[idx]->get_operator().name() != "undefined";
            };

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[2]->get_shape().lens()[2];
            std::size_t batch_size  = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();
            migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};

            // Zero initial hidden state, used when initial_h is not provided.
            // migraphx::literal{shape, vector} copies one vector entry per *element*
            // (converting it to the shape's type), so the vector must have
            // elements() entries; sizing it by bytes() overruns the literal's buffer
            // for any element type wider than one byte.
            std::vector<float> data(ih_shape.elements(), 0.0f);
            auto zero_ih = [&] { return prog.add_literal(migraphx::literal{ih_shape, data}); };

            // Slice one direction (0 = forward, 1 = reverse) out of a per-direction
            // input along axis 0.
            auto slice_direction = [&](instruction_ref arg, int dir) {
                return prog.insert_instruction(ins, op::slice{{0}, {dir}, {dir + 1}}, arg);
            };

            auto gru_op                    = any_cast<op::gru>(ins->get_operator());
            op::gru::gru_direction_t dicrt = gru_op.direction;
            if(dicrt == op::gru::bidirectional)
            {
                // weight matrices
                auto w_forward = slice_direction(args[1], 0);
                auto w_reverse = slice_direction(args[1], 1);
                auto r_forward = slice_direction(args[2], 0);
                auto r_reverse = slice_direction(args[2], 1);

                // bias; prog.end() tells gru_cell there is none
                instruction_ref bias_forward = prog.end();
                instruction_ref bias_reverse = prog.end();
                if(is_provided(3))
                {
                    bias_forward = slice_direction(args[3], 0);
                    bias_reverse = slice_direction(args[3], 1);
                }

                // initial hidden state
                instruction_ref ih_forward;
                instruction_ref ih_reverse;
                if(is_provided(5))
                {
                    ih_forward = slice_direction(args[5], 0);
                    ih_reverse = slice_direction(args[5], 1);
                }
                else
                {
                    ih_forward = zero_ih();
                    ih_reverse = zero_ih();
                }

                // actv_funcs[0..1] drive the forward direction, [2..3] the reverse one
                auto ret_forward = gru_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            gru_op.linear_before_reset,
                                            gru_op.actv_funcs.at(0),
                                            gru_op.actv_funcs.at(1));

                auto ret_reverse = gru_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            gru_op.linear_before_reset,
                                            gru_op.actv_funcs.at(2),
                                            gru_op.actv_funcs.at(3));

                // last hidden state of both directions, stacked along axis 0
                auto last_output =
                    prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);

                // add the num_directions axis to each output sequence
                ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
                ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);

                // concat the forward and reverse output; ins is replaced in place
                auto hidden_state = prog.replace_instruction(
                    ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
                map_last_output[hidden_state] = last_output;
            }
            else
            {
                const bool is_forward = (dicrt == op::gru::forward);
                // weight matrices are used as-is (single direction)
                auto w = args[1];
                auto r = args[2];
                // prog.end() tells gru_cell there is no bias
                instruction_ref bias = is_provided(3) ? args[3] : prog.end();
                instruction_ref ih   = is_provided(5) ? args[5] : zero_ih();

                auto ret = gru_cell(is_forward,
                                    prog,
                                    ins,
                                    args[0],
                                    w,
                                    r,
                                    bias,
                                    ih,
                                    gru_op.linear_before_reset,
                                    gru_op.actv_funcs.at(0),
                                    gru_op.actv_funcs.at(1));

                auto last_output = ret[1];

                // add the num_directions axis; ins is replaced in place
                auto hidden_state = prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
                map_last_output[hidden_state] = last_output;
            }
        }

        // Rewrite a gru_last_output that consumes a rewritten gru. Its value could be
        // sliced out of the gru's output, but gru_cell already produced the last
        // hidden state as a separate instruction, so that instruction is reused.
        if(ins->name() == "gru_last_output")
        {
            auto inputs = ins->inputs();
            assert(inputs.size() == 1);
            assert(map_last_output.count(inputs[0]) > 0);
            // at() rather than operator[]: a missing entry must not silently insert a
            // default-constructed (singular) instruction_ref and hand it to
            // replace_instruction when asserts are compiled out.
            prog.replace_instruction(ins, map_last_output.at(inputs[0]));
        }
    }
}

158
std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
159
160
161
162
163
164
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
                                                   instruction_ref w,
                                                   instruction_ref r,
                                                   instruction_ref bias,
165
                                                   instruction_ref ih,
166
167
168
169
                                                   int linear_before_reset,
                                                   operation& actv_func1,
                                                   operation& actv_func2) const
{
170
    instruction_ref hidden_out, last_out;
Shucai Xiao's avatar
Shucai Xiao committed
171
172
    long seq_len = static_cast<long>(input->get_shape().lens()[0]);
    long hs      = static_cast<long>(r->get_shape().lens()[2]);
173

Shucai Xiao's avatar
Shucai Xiao committed
174
175
    migraphx::shape s(input->get_shape().type(),
                      {input->get_shape().lens()[1], static_cast<std::size_t>(hs)});
176
177
    std::vector<int> data(s.elements(), 1);
    auto l1 = prog.add_literal(migraphx::literal{s, data});
178
179
180

    // weight matrix
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
181
182
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto wz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sw);
183
184
    auto tran_wz = prog.insert_instruction(ins, op::transpose{perm}, wz);

Shucai Xiao's avatar
Shucai Xiao committed
185
    auto wr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sw);
186
    auto tran_wr = prog.insert_instruction(ins, op::transpose{perm}, wr);
Shucai Xiao's avatar
Shucai Xiao committed
187
188

    auto wh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sw);
189
190
    auto tran_wh = prog.insert_instruction(ins, op::transpose{perm}, wh);

Shucai Xiao's avatar
Shucai Xiao committed
191
192
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto rz      = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sr);
193
194
    auto tran_rz = prog.insert_instruction(ins, op::transpose{perm}, rz);

Shucai Xiao's avatar
Shucai Xiao committed
195
    auto rr      = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sr);
196
197
    auto tran_rr = prog.insert_instruction(ins, op::transpose{perm}, rr);

Shucai Xiao's avatar
Shucai Xiao committed
198
    auto rh      = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sr);
199
200
201
202
    auto tran_rh = prog.insert_instruction(ins, op::transpose{perm}, rh);

    // initial states
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);
203
204

    // bias
205
    instruction_ref brcst_bz, brcst_br, brcst_wbh, brcst_rbh, brcst_bh;
Shucai Xiao's avatar
Shucai Xiao committed
206
    if(bias != prog.end())
207
    {
208
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
        auto wbz   = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto wbr   = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto wbh   = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, sbias);
        brcst_wbh  = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, wbh);

        auto rbz  = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, sbias);
        auto rbr  = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, sbias);
        auto rbh  = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, sbias);
        brcst_rbh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, rbh);

        auto bz  = prog.insert_instruction(ins, op::add{}, wbz, rbz);
        brcst_bz = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bz);

        auto br  = prog.insert_instruction(ins, op::add{}, wbr, rbr);
        brcst_br = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, br);

        auto bh  = prog.insert_instruction(ins, op::add{}, wbh, rbh);
        brcst_bh = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, bh);
227
228
    }

229
    long seq_index = is_forward ? 0 : seq_len - 1;
230
231
232
233
    for(long i = 0; i < seq_len; i++)
    {
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
234

235
        // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
Shucai Xiao's avatar
Shucai Xiao committed
236
237
        auto xt_wz = prog.insert_instruction(ins, op::dot{}, xt, tran_wz);
        auto ht_rz = prog.insert_instruction(ins, op::dot{}, sih, tran_rz);
238
        auto xht_z = prog.insert_instruction(ins, op::add{}, xt_wz, ht_rz);
Shucai Xiao's avatar
Shucai Xiao committed
239
        if(bias != prog.end())
240
        {
241
            xht_z = prog.insert_instruction(ins, op::add{}, xht_z, brcst_bz);
242
        }
243
        auto zt = prog.insert_instruction(ins, actv_func1, xht_z);
244
245

        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
Shucai Xiao's avatar
Shucai Xiao committed
246
247
        auto xt_wr = prog.insert_instruction(ins, op::dot{}, xt, tran_wr);
        auto ht_rr = prog.insert_instruction(ins, op::dot{}, sih, tran_rr);
248
        auto xht_r = prog.insert_instruction(ins, op::add{}, xt_wr, ht_rr);
Shucai Xiao's avatar
Shucai Xiao committed
249
        if(bias != prog.end())
250
        {
251
            xht_r = prog.insert_instruction(ins, op::add{}, xht_r, brcst_br);
252
        }
253
        auto rt = prog.insert_instruction(ins, actv_func1, xht_r);
254

255
        instruction_ref xht_h;
Shucai Xiao's avatar
Shucai Xiao committed
256
        if(linear_before_reset == 0)
257
258
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
259
260
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
            auto rt_ht1 = prog.insert_instruction(ins, op::mul{}, rt, sih);
Shucai Xiao's avatar
Shucai Xiao committed
261
262
            auto rt_rh  = prog.insert_instruction(ins, op::dot{}, rt_ht1, tran_rh);
            xht_h       = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
Shucai Xiao's avatar
Shucai Xiao committed
263
            if(bias != prog.end())
264
            {
265
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_bh);
266
267
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
268
        else
269
270
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
Shucai Xiao's avatar
Shucai Xiao committed
271
            auto xt_wh  = prog.insert_instruction(ins, op::dot{}, xt, tran_wh);
272
            auto ht1_rh = prog.insert_instruction(ins, op::dot{}, sih, tran_rh);
Shucai Xiao's avatar
Shucai Xiao committed
273
            if(bias != prog.end())
274
            {
275
                ht1_rh = prog.insert_instruction(ins, op::add{}, ht1_rh, brcst_rbh);
276
            }
277
            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ht1_rh);
Shucai Xiao's avatar
Shucai Xiao committed
278
            xht_h      = prog.insert_instruction(ins, op::add{}, xt_wh, rt_rh);
Shucai Xiao's avatar
Shucai Xiao committed
279
            if(bias != prog.end())
280
            {
281
                xht_h = prog.insert_instruction(ins, op::add{}, xht_h, brcst_wbh);
282
283
            }
        }
284
        auto ht = prog.insert_instruction(ins, actv_func2, xht_h);
285
286

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
Shucai Xiao's avatar
Shucai Xiao committed
287
        auto one_minus_zt    = prog.insert_instruction(ins, op::sub{}, l1, zt);
288
        auto one_minus_zt_ht = prog.insert_instruction(ins, op::mul{}, one_minus_zt, ht);
Shucai Xiao's avatar
Shucai Xiao committed
289
290
291
        auto zt_ht1          = prog.insert_instruction(ins, op::mul{}, zt, sih);
        sih                  = prog.insert_instruction(ins, op::add{}, one_minus_zt_ht, zt_ht1);
        last_out             = prog.insert_instruction(ins, op::unsqueeze{{0}}, sih);
292
293
294

        if(is_forward)
        {
Shucai Xiao's avatar
Shucai Xiao committed
295
            hidden_out = (seq_index == 0)
296
297
                             ? last_out
                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
298
299
300
301
        }
        else
        {
            hidden_out = (seq_index == seq_len - 1)
302
303
                             ? last_out
                             : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
304
305
306
307
308
309
        }
        seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
    }

    std::vector<instruction_ref> out_args;
    out_args.push_back(hidden_out);
310
    out_args.push_back(last_out);
311
312
313
314
315
316

    return out_args;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx