rewrite_gru.cpp 12.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
#include <migraphx/rewrite_gru.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Replace every "gru" instruction in the program with the equivalent sequence
// of primitive instructions (slice, squeeze, dot, add, mul, ...) generated by
// gru_oper().
//
// For a bidirectional GRU, the weights, bias and initial hidden state are
// split along axis 0 into a forward and a reverse part, each direction is
// expanded separately, and the two results are concatenated along the
// num_directions axis (axis 1). For a single direction, the leading
// num_directions axis of the weight inputs is squeezed away and only one
// expansion is generated.
//
// prog: the program to be rewritten in place.
void rewrite_gru::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        if(ins->name() != "gru")
        {
            continue;
        }

        // There can be 3 to 5 inputs: sequence, W, R, optional bias and
        // optional initial hidden state. (Although the onnx operator has 6
        // inputs, the original note here says the 5th one is undefined and
        // ignored by protobuf, so only up to 5 inputs are processed.)
        auto args = ins->inputs();

        shape seq_shape         = args[0]->get_shape();
        std::size_t hidden_size = args[2]->get_shape().lens()[2];
        std::size_t batchs      = seq_shape.lens()[1];
        shape::type_t type      = seq_shape.type();

        // Zero-filled buffer used as the initial hidden state when the
        // instruction does not provide one.
        migraphx::shape ih_shape{type, {batchs, hidden_size}};
        std::vector<char> data(ih_shape.bytes(), 0);

        auto gru_op                    = any_cast<op::gru>(ins->get_operator());
        op::gru::gru_direction_t dicrt = gru_op.direction;
        if(dicrt == op::gru::bidirectional)
        {
            // Select direction `dir` (0 = forward, 1 = reverse) along axis 0
            // of `arg` and drop that axis.
            auto select_direction = [&](instruction_ref arg, int64_t dir) {
                auto part = prog.insert_instruction(ins, op::slice{{0}, {dir}, {dir + 1}}, arg);
                return prog.insert_instruction(ins, op::squeeze{{0}}, part);
            };

            // forward weight
            auto w_forward = select_direction(args[1], 0);
            auto r_forward = select_direction(args[2], 0);

            // reverse weight
            auto w_reverse = select_direction(args[1], 1);
            auto r_reverse = select_direction(args[2], 1);

            // process bias; prog.end() tells gru_oper that there is no bias
            instruction_ref bias_forward = prog.end();
            instruction_ref bias_reverse = prog.end();
            if(args.size() >= 4)
            {
                bias_forward = select_direction(args[3], 0);
                bias_reverse = select_direction(args[3], 1);
            }

            // initial hidden state
            instruction_ref ih_forward;
            instruction_ref ih_reverse;
            if(args.size() >= 5)
            {
                ih_forward = select_direction(args[4], 0);
                ih_reverse = select_direction(args[4], 1);
            }
            else
            {
                ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
            }

            // The forward direction uses activation functions 0 and 1, the
            // reverse direction uses activation functions 2 and 3.
            auto ret_forward = gru_oper(true,
                                        prog,
                                        ins,
                                        args[0],
                                        w_forward,
                                        r_forward,
                                        ih_forward,
                                        bias_forward,
                                        gru_op.linear_before_reset,
                                        gru_op.actv_funcs.at(0),
                                        gru_op.actv_funcs.at(1));

            auto ret_reverse = gru_oper(false,
                                        prog,
                                        ins,
                                        args[0],
                                        w_reverse,
                                        r_reverse,
                                        ih_reverse,
                                        bias_reverse,
                                        gru_op.linear_before_reset,
                                        gru_op.actv_funcs.at(2),
                                        gru_op.actv_funcs.at(3));

            // NOTE: the final hidden states (ret_forward[1], ret_reverse[1])
            // are currently not used by this rewrite.

            // add the dimension of num_direction
            ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
            ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);

            // concat the forward and reverse output
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
        else
        {
            const bool is_forward = (dicrt == op::gru::forward);

            // weight matrices without the num_directions axis
            auto w = prog.insert_instruction(ins, op::squeeze{{0}}, args[1]);
            auto r = prog.insert_instruction(ins, op::squeeze{{0}}, args[2]);

            // bias; prog.end() tells gru_oper that there is no bias
            instruction_ref bias = prog.end();
            if(args.size() >= 4)
            {
                bias = prog.insert_instruction(ins, op::squeeze{{0}}, args[3]);
            }

            // initial hidden state
            instruction_ref ih;
            if(args.size() >= 5)
            {
                ih = prog.insert_instruction(ins, op::squeeze{{0}}, args[4]);
            }
            else
            {
                ih = prog.add_literal(migraphx::literal{ih_shape, data});
            }

            auto ret = gru_oper(is_forward,
                                prog,
                                ins,
                                args[0],
                                w,
                                r,
                                ih,
                                bias,
                                gru_op.linear_before_reset,
                                gru_op.actv_funcs.at(0),
                                gru_op.actv_funcs.at(1));

            // add the dimension of num_direction
            prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
        }
    }
}

std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
                                                   instruction_ref w,
                                                   instruction_ref r,
                                                   instruction_ref ih,
                                                   instruction_ref bias,
                                                   int linear_before_reset,
                                                   operation& actv_func1,
                                                   operation& actv_func2) const
{
    instruction_ref hidden_out, final_out;
    long seq_len = static_cast<long>(input->get_shape().lens()[0]);
    long hs = static_cast<long>(r->get_shape().lens()[1]);
    long seq_index              = is_forward ? 0 : seq_len - 1;

    migraphx::shape s(input->get_shape().type(), {1});
    auto l1 = prog.add_literal(migraphx::leteral{s, {1}});

    // weight matrix
    std::vector<int64_t> perm{1, 0};
    auto wz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, w);
    auto twz = prog.insert_instruction(ins, op::transpose{perm}, wz);
    auto wr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, w);
    auto twr = prog.insert_instruction(ins, op::transpose{perm}, wr);
    auto wh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, w);
    auto twh = prog.insert_instruction(ins, op::transpose{perm}, wh);

    auto rz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, r);
    auto trz = prog.insert_instruction(ins, op::transpose{perm}, rz);
    auto rr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, r);
    auto trr = prog.insert_instruction(ins, op::transpose{perm}, rr);
    auto rh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, r);
    auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);

    // bias
    instruction_ref br_bz, br_br, br_wbh, br_rbh, br_bh;
    if (bias != prog.end())
    {
        auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, bias);
        auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2*hs}}, bias);
        wbh = prog.insert_instruction(ins, op::slice{{0}, {2*hs}, {3*hs}}, bias);
        br_wbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, wbh);

        auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3*hs}, {4*hs}}, bias);
        auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4*hs}, {5*hs}}, bias);
        rbh = prog.insert_instruction(ins, op::slice{{0}, {5*hs}, {6*hs}}, bias);
        br_rbh = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, rbh);

        auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
        br_bz = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, bz);
        auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
        br_br = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, br);
        br_bh = prog.insert_instruction(ins, op::add{}, br_wbh, br_rbh);
    }

    for(long i = 0; i < seq_len; i++)
    {
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
        // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
        auto xwzt = prog.insert_instruction(ins, op::dot{}, xt, twz);
        auto hrzt = prog.insert_instruction(ins, op::dot{}, ih, trz);
        auto xwhr_zt = prog.insert_instruction(ins, op::add{}, xwzt, hrzt);
        if (bias != prog.end()) 
        {
            xwhr_zt = prog.insert_instruction(ins, op::add{}, xwhr_zt, br_bz);
        }
        auto zt = prog.insert_instruction(ins, actv_func1, xwhr_zt);

        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
        auto xwrt = prog.insert_instruction(ins, op::dot{}, xt, twr);
        auto hrrt = prog.insert_instruction(ins, op::dot{}, xt, trr);
        auto xwhr_rt = prog.insert_instruction(ins, op::add{}, xwrt, hrrt);
        if (bias != prog.end()) 
        {
            xwhr_rt = prog.insert_instruction(ins, op::add{}, xwhr_rt, br_br);
        }
        auto rt = prog.insert_instruction(ins, actv_func1, xwhr_rt);

        instruction_ref xwhh_rt;
        if (linear_before_reset == 0)
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
            auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
            auto rt_ht = prog.insert_instruction(ins, op::mul{}, rt, ih);
            auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht, trh);
            xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rt);
            if (bias != prog.end())
            {
                xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_bh);
            }
        }
        else 
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
            auto xwht = prog.insert_instruction(ins, op::dot{}, xt, twh);
            auto ih_rht = prog.insert_instruction(ins, op::dot{}, ih, twh);
            if (bias != prog.end())
            {
                ih_rht = prog.insert_instruction(ins, op::add{}, ih_rht, br_rbh);
            }
            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ih_rht);
            xwhh_rt = prog.insert_instruction(ins, op::add{}, xwht, rt_rh);
            if (bias != prog.end())
            {
                xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_wbh);
            }
        }
        ht = prog.insert_instruction(ins, actv_func2, xwhh_rt);

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto 1zt = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto 1ztht = prog.insert_instruction(ins, op::mul{}, 1zt, ht);
        auto ztht1 = prog.insert_instruction(ins, op::mul{}, zt, ih);
        ih = prog.insert_instruction(ins, op::add{}, 1ztht ztht1);
        final_out   = ih;

        if(is_forward)
        {
            hidden_out = (seq_index == 0)
                             ? ih
                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, ih);
        }
        else
        {
            hidden_out = (seq_index == seq_len - 1)
                             ? ih
                             : prog.insert_instruction(ins, op::concat{0}, ih, hidden_out);
        }
        seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
    }

    std::vector<instruction_ref> out_args;
    out_args.push_back(hidden_out);
    out_args.push_back(final_out);

    return out_args;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx