"src/targets/gpu/vscode:/vscode.git/clone" did not exist on "8ba7c0a07344faab8e2936fb8185a366cd2a019c"
rewrite_gru.cpp 13.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include <migraphx/rewrite_gru.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_gru::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        if(ins->name() != "gru")
        {
            continue;
        }

        // A gru instruction has 3 to 5 inputs: X, W, R and optionally B and
        // initial_h. (The ONNX GRU operator lists 6 inputs, but the 5th one is
        // undefined and ignored by protobuf, so at most 5 are processed here.)
        auto args = ins->inputs();

        shape seq_shape         = args[0]->get_shape();
        std::size_t hidden_size = args[2]->get_shape().lens()[2];
        std::size_t batchs      = seq_shape.lens()[1];
        shape::type_t type      = seq_shape.type();

        // Zero initial hidden state {batch, hidden_size}, used when initial_h is
        // not supplied. migraphx::literal(shape, std::vector<T>) fills the tensor
        // one vector entry per ELEMENT (converting to `type`), so the vector is
        // sized by elements(); a bytes()-sized vector would overrun for any type
        // wider than one byte.
        migraphx::shape ih_shape{type, {batchs, hidden_size}};
        std::vector<float> data(ih_shape.elements(), 0.0f);

        // Pick direction `dir` (0 = forward, 1 = reverse) out of an input whose
        // leading axis is num_directions, then drop that axis.
        auto take_direction = [&](instruction_ref arg, int64_t dir) {
            auto sliced = prog.insert_instruction(ins, op::slice{{0}, {dir}, {dir + 1}}, arg);
            return prog.insert_instruction(ins, op::squeeze{{0}}, sliced);
        };

        auto gru_op                    = any_cast<op::gru>(ins->get_operator());
        op::gru::gru_direction_t dicrt = gru_op.direction;
        if(dicrt == op::gru::bidirectional)
        {
            // weights: W and R for each direction
            auto w_forward = take_direction(args[1], 0);
            auto r_forward = take_direction(args[2], 0);
            auto w_reverse = take_direction(args[1], 1);
            auto r_reverse = take_direction(args[2], 1);

            // bias (optional); prog.end() tells gru_oper there is no bias
            instruction_ref bias_forward = prog.end();
            instruction_ref bias_reverse = prog.end();
            if(args.size() >= 4)
            {
                bias_forward = take_direction(args[3], 0);
                bias_reverse = take_direction(args[3], 1);
            }

            // initial hidden state (optional); zeros when absent
            instruction_ref ih_forward;
            instruction_ref ih_reverse;
            if(args.size() >= 5)
            {
                ih_forward = take_direction(args[4], 0);
                ih_reverse = take_direction(args[4], 1);
            }
            else
            {
                ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
            }

            auto ret_forward = gru_oper(true,
                                        prog,
                                        ins,
                                        args[0],
                                        w_forward,
                                        r_forward,
                                        ih_forward,
                                        bias_forward,
                                        gru_op.linear_before_reset,
                                        gru_op.actv_funcs.at(0),
                                        gru_op.actv_funcs.at(1));

            auto ret_reverse = gru_oper(false,
                                        prog,
                                        ins,
                                        args[0],
                                        w_reverse,
                                        r_reverse,
                                        ih_reverse,
                                        bias_reverse,
                                        gru_op.linear_before_reset,
                                        gru_op.actv_funcs.at(2),
                                        gru_op.actv_funcs.at(3));

            // gru_oper's first output holds the hidden states of every step;
            // add the num_directions axis, then join the two directions on it.
            ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
            ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
        else
        {
            bool is_forward = (dicrt == op::gru::forward);

            // single direction: just drop the num_directions axis
            auto w = prog.insert_instruction(ins, op::squeeze{{0}}, args[1]);
            auto r = prog.insert_instruction(ins, op::squeeze{{0}}, args[2]);

            // bias (optional); prog.end() tells gru_oper there is no bias
            instruction_ref bias = prog.end();
            if(args.size() >= 4)
            {
                bias = prog.insert_instruction(ins, op::squeeze{{0}}, args[3]);
            }

            // initial hidden state (optional); zeros when absent
            instruction_ref ih;
            if(args.size() >= 5)
            {
                ih = prog.insert_instruction(ins, op::squeeze{{0}}, args[4]);
            }
            else
            {
                ih = prog.add_literal(migraphx::literal{ih_shape, data});
            }

            auto ret = gru_oper(is_forward,
                                prog,
                                ins,
                                args[0],
                                w,
                                r,
                                ih,
                                bias,
                                gru_op.linear_before_reset,
                                gru_op.actv_funcs.at(0),
                                gru_op.actv_funcs.at(1));

            // add the num_directions axis
            prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
        }
    }
}

std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
                                                   instruction_ref w,
                                                   instruction_ref r,
                                                   instruction_ref ih,
                                                   instruction_ref bias,
                                                   int linear_before_reset,
                                                   operation& actv_func1,
                                                   operation& actv_func2) const
{
    instruction_ref hidden_out, final_out;
Shucai Xiao's avatar
Shucai Xiao committed
172
173
174
    long seq_len   = static_cast<long>(input->get_shape().lens()[0]);
    long hs        = static_cast<long>(r->get_shape().lens()[1]);
    long seq_index = is_forward ? 0 : seq_len - 1;
175
176

    migraphx::shape s(input->get_shape().type(), {1});
Shucai Xiao's avatar
Shucai Xiao committed
177
    auto l1 = prog.add_literal(migraphx::literal{s, {1}});
178
179
180

    // weight matrix
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
181
    auto wz  = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, w);
182
    auto twz = prog.insert_instruction(ins, op::transpose{perm}, wz);
Shucai Xiao's avatar
Shucai Xiao committed
183
    auto wr  = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, w);
184
    auto twr = prog.insert_instruction(ins, op::transpose{perm}, wr);
Shucai Xiao's avatar
Shucai Xiao committed
185
    auto wh  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, w);
186
187
    auto twh = prog.insert_instruction(ins, op::transpose{perm}, wh);

Shucai Xiao's avatar
Shucai Xiao committed
188
    auto rz  = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, r);
189
    auto trz = prog.insert_instruction(ins, op::transpose{perm}, rz);
Shucai Xiao's avatar
Shucai Xiao committed
190
    auto rr  = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, r);
191
    auto trr = prog.insert_instruction(ins, op::transpose{perm}, rr);
Shucai Xiao's avatar
Shucai Xiao committed
192
    auto rh  = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, r);
193
194
195
196
    auto trh = prog.insert_instruction(ins, op::transpose{perm}, rh);

    // bias
    instruction_ref br_bz, br_br, br_wbh, br_rbh, br_bh;
Shucai Xiao's avatar
Shucai Xiao committed
197
    if(bias != prog.end())
198
199
    {
        auto wbz = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
200
        auto wbr = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
201
        auto wbh = prog.insert_instruction(ins, op::slice{{0}, {2 * hs}, {3 * hs}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
202
        br_wbh   = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, wbh);
203

Shucai Xiao's avatar
Shucai Xiao committed
204
205
        auto rbz = prog.insert_instruction(ins, op::slice{{0}, {3 * hs}, {4 * hs}}, bias);
        auto rbr = prog.insert_instruction(ins, op::slice{{0}, {4 * hs}, {5 * hs}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
206
        auto rbh = prog.insert_instruction(ins, op::slice{{0}, {5 * hs}, {6 * hs}}, bias);
Shucai Xiao's avatar
Shucai Xiao committed
207
        br_rbh   = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, rbh);
208
209

        auto bz = prog.insert_instruction(ins, op::add{}, wbz, rbz);
Shucai Xiao's avatar
Shucai Xiao committed
210
        br_bz   = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, bz);
211
        auto br = prog.insert_instruction(ins, op::add{}, wbr, rbr);
Shucai Xiao's avatar
Shucai Xiao committed
212
213
        br_br   = prog.insert_instruction(ins, op::broadcast{1, ih->get_shape()}, br);
        br_bh   = prog.insert_instruction(ins, op::add{}, br_wbh, br_rbh);
214
215
216
217
218
219
220
    }

    for(long i = 0; i < seq_len; i++)
    {
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
        // equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
Shucai Xiao's avatar
Shucai Xiao committed
221
222
        auto xwzt    = prog.insert_instruction(ins, op::dot{}, xt, twz);
        auto hrzt    = prog.insert_instruction(ins, op::dot{}, ih, trz);
223
        auto xwhr_zt = prog.insert_instruction(ins, op::add{}, xwzt, hrzt);
Shucai Xiao's avatar
Shucai Xiao committed
224
        if(bias != prog.end())
225
226
227
228
229
230
        {
            xwhr_zt = prog.insert_instruction(ins, op::add{}, xwhr_zt, br_bz);
        }
        auto zt = prog.insert_instruction(ins, actv_func1, xwhr_zt);

        // equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
Shucai Xiao's avatar
Shucai Xiao committed
231
232
        auto xwrt    = prog.insert_instruction(ins, op::dot{}, xt, twr);
        auto hrrt    = prog.insert_instruction(ins, op::dot{}, xt, trr);
233
        auto xwhr_rt = prog.insert_instruction(ins, op::add{}, xwrt, hrrt);
Shucai Xiao's avatar
Shucai Xiao committed
234
        if(bias != prog.end())
235
236
237
238
239
240
        {
            xwhr_rt = prog.insert_instruction(ins, op::add{}, xwhr_rt, br_br);
        }
        auto rt = prog.insert_instruction(ins, actv_func1, xwhr_rt);

        instruction_ref xwhh_rt;
Shucai Xiao's avatar
Shucai Xiao committed
241
        if(linear_before_reset == 0)
242
243
        {
            // equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
Shucai Xiao's avatar
Shucai Xiao committed
244
            auto xwht  = prog.insert_instruction(ins, op::dot{}, xt, twh);
245
246
            auto rt_ht = prog.insert_instruction(ins, op::mul{}, rt, ih);
            auto rt_rh = prog.insert_instruction(ins, op::dot{}, rt_ht, trh);
Shucai Xiao's avatar
Shucai Xiao committed
247
            xwhh_rt    = prog.insert_instruction(ins, op::add{}, xwht, rt_rh);
Shucai Xiao's avatar
Shucai Xiao committed
248
            if(bias != prog.end())
249
250
251
252
            {
                xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_bh);
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
253
        else
254
255
        {
            // equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
Shucai Xiao's avatar
Shucai Xiao committed
256
            auto xwht   = prog.insert_instruction(ins, op::dot{}, xt, twh);
257
            auto ih_rht = prog.insert_instruction(ins, op::dot{}, ih, twh);
Shucai Xiao's avatar
Shucai Xiao committed
258
            if(bias != prog.end())
259
260
261
262
            {
                ih_rht = prog.insert_instruction(ins, op::add{}, ih_rht, br_rbh);
            }
            auto rt_rh = prog.insert_instruction(ins, op::mul{}, rt, ih_rht);
Shucai Xiao's avatar
Shucai Xiao committed
263
264
            xwhh_rt    = prog.insert_instruction(ins, op::add{}, xwht, rt_rh);
            if(bias != prog.end())
265
266
267
268
            {
                xwhh_rt = prog.insert_instruction(ins, op::add{}, xwhh_rt, br_wbh);
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
269
        auto ht = prog.insert_instruction(ins, actv_func2, xwhh_rt);
270
271

        // equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
Shucai Xiao's avatar
Shucai Xiao committed
272
273
        auto z1t   = prog.insert_instruction(ins, op::sub{}, l1, zt);
        auto z1tht = prog.insert_instruction(ins, op::mul{}, z1t, ht);
274
        auto ztht1 = prog.insert_instruction(ins, op::mul{}, zt, ih);
Shucai Xiao's avatar
Shucai Xiao committed
275
        ih         = prog.insert_instruction(ins, op::add{}, z1tht, ztht1);
Shucai Xiao's avatar
Shucai Xiao committed
276
        final_out  = ih;
277
278
279

        if(is_forward)
        {
Shucai Xiao's avatar
Shucai Xiao committed
280
281
            hidden_out =
                (seq_index == 0) ? ih : prog.insert_instruction(ins, op::concat{0}, hidden_out, ih);
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
        }
        else
        {
            hidden_out = (seq_index == seq_len - 1)
                             ? ih
                             : prog.insert_instruction(ins, op::concat{0}, ih, hidden_out);
        }
        seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
    }

    std::vector<instruction_ref> out_args;
    out_args.push_back(hidden_out);
    out_args.push_back(final_out);

    return out_args;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx