rewrite_rnn.cpp 12.3 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <migraphx/rewrite_rnn.hpp>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_rnn::apply(program& prog) const
{
Shucai Xiao's avatar
Shucai Xiao committed
13
    std::unordered_map<instruction_ref, instruction_ref> map_last_output;
Shucai Xiao's avatar
Shucai Xiao committed
14
15
    for(auto ins : iterator_for(prog))
    {
Shucai Xiao's avatar
Shucai Xiao committed
16
17
        // rewrite rnn operator
        if(ins->name() == "rnn")
Shucai Xiao's avatar
Shucai Xiao committed
18
        {
19
20
21
22
            // could be 3 to 6 inputs, but the parse_rnn function will
            // append undefined operators to make 6 arguments when parsing
            // an onnx file. Another case is user can have only 3 arguments
            // when writing their program.
Shucai Xiao's avatar
Shucai Xiao committed
23
24
25
26
27
28
            auto args = ins->inputs();

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[1]->get_shape().lens()[1];
            std::size_t batch_size  = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();
29
30
            migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
            std::vector<float> data(ih_shape.elements(), 0);
Shucai Xiao's avatar
Shucai Xiao committed
31

Shucai Xiao's avatar
Shucai Xiao committed
32
            auto actv_funcs                = compute_actv_funcs(ins);
Shucai Xiao's avatar
Shucai Xiao committed
33
34
            auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
            op::rnn::rnn_direction_t dicrt = rnn_op.direction;
35
            if(dicrt == op::rnn::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
36
            {
Shucai Xiao's avatar
Shucai Xiao committed
37
                // input weight matrix
Shucai Xiao's avatar
Shucai Xiao committed
38
39
                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
Shucai Xiao's avatar
Shucai Xiao committed
40
41

                // hidden state weight matrix
Shucai Xiao's avatar
Shucai Xiao committed
42
43
                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
Shucai Xiao's avatar
Shucai Xiao committed
44
45

                // process bias
Shucai Xiao's avatar
Shucai Xiao committed
46
47
                instruction_ref bias_forward = prog.end();
                instruction_ref bias_reverse = prog.end();
48
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
49
50
51
52
53
54
55
                {
                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                }

                // process intial hidden state, it could be the 6th argument
                // or the 5th one (if the sequence len argument is ignored)
Shucai Xiao's avatar
Shucai Xiao committed
56
57
                instruction_ref ih_forward{};
                instruction_ref ih_reverse{};
58
                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
59
                {
Shucai Xiao's avatar
Shucai Xiao committed
60
61
                    ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
                    ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
Shucai Xiao's avatar
Shucai Xiao committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
                }
                else
                {
                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret_forward = rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
77
                                            actv_funcs.at(0));
Shucai Xiao's avatar
Shucai Xiao committed
78
79
80
81
82
83
84
85
                auto ret_reverse = rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
86
                                            actv_funcs.at(1));
Shucai Xiao's avatar
Shucai Xiao committed
87

Shucai Xiao's avatar
Shucai Xiao committed
88
89
                auto concat_output =
                    prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
90
                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
Shucai Xiao's avatar
Shucai Xiao committed
91

Shucai Xiao's avatar
Shucai Xiao committed
92
93
94
                // The following logic is to ensure the last instruction rewritten from
                // rnn operator is a concat instruction
                // sequence len is 1
Shucai Xiao's avatar
Shucai Xiao committed
95
                instruction_ref hidden_output{};
Shucai Xiao's avatar
Shucai Xiao committed
96
                if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
97
                {
Shucai Xiao's avatar
Shucai Xiao committed
98
99
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
100
                }
Shucai Xiao's avatar
Shucai Xiao committed
101
                else
Shucai Xiao's avatar
Shucai Xiao committed
102
                {
Shucai Xiao's avatar
Shucai Xiao committed
103
104
105
106
                    ret_forward[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
                    ret_reverse[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
Shucai Xiao's avatar
Shucai Xiao committed
107
108
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
Shucai Xiao's avatar
Shucai Xiao committed
109
                }
Shucai Xiao's avatar
Shucai Xiao committed
110
                map_last_output[hidden_output] = last_output;
Shucai Xiao's avatar
Shucai Xiao committed
111
112
113
            }
            else
            {
114
                bool is_forward = (dicrt == op::rnn::forward);
Shucai Xiao's avatar
Shucai Xiao committed
115
116
117
118
119
120
121
122
                // input weight matrix
                auto w = args[1];

                // hidden state weight matrix
                auto r = args[2];

                // process bias and initial hidden state
                instruction_ref bias = prog.end();
123
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
124
125
126
127
128
129
                {
                    bias = args[3];
                }

                // process intial hidden state
                instruction_ref ih;
Shucai Xiao's avatar
Shucai Xiao committed
130
                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
131
                {
132
                    ih = args[5];
Shucai Xiao's avatar
Shucai Xiao committed
133
134
135
136
137
138
                }
                else
                {
                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                }

Shucai Xiao's avatar
Shucai Xiao committed
139
140
                auto ret =
                    rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
Shucai Xiao's avatar
Shucai Xiao committed
141
                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
142

Shucai Xiao's avatar
Shucai Xiao committed
143
                // following logic is to ensure the last instruction is a
Shucai Xiao's avatar
Shucai Xiao committed
144
145
                // concat instruction
                // sequence len is 1
Shucai Xiao's avatar
Shucai Xiao committed
146
                instruction_ref hidden_output{};
Shucai Xiao's avatar
Shucai Xiao committed
147
                if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
148
                {
Shucai Xiao's avatar
Shucai Xiao committed
149
                    hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
150
151
152
153
154
                }
                else
                {
                    auto concat_arg0 = is_forward ? ret[0] : ret[1];
                    auto concat_arg1 = is_forward ? ret[1] : ret[0];
Shucai Xiao's avatar
Shucai Xiao committed
155
156
                    hidden_output =
                        prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
157
                }
Shucai Xiao's avatar
Shucai Xiao committed
158
                map_last_output[hidden_output] = last_output;
Shucai Xiao's avatar
Shucai Xiao committed
159
160
161
            }
        }

Shucai Xiao's avatar
Shucai Xiao committed
162
        // rewrite the rnn_last_output operator that right after the rnn
Shucai Xiao's avatar
Shucai Xiao committed
163
164
165
        // operator. Intuitively, we can do a slice on the input to get
        // the last output, but it is already existed in the rnn operator,
        // so we can just use it as the output here
Shucai Xiao's avatar
Shucai Xiao committed
166
        if(ins->name() == "rnn_last_output")
Shucai Xiao's avatar
Shucai Xiao committed
167
        {
Shucai Xiao's avatar
Shucai Xiao committed
168
169
170
            auto inputs = ins->inputs();
            assert(inputs.size() == 1);
            auto arg = inputs[0];
Shucai Xiao's avatar
Shucai Xiao committed
171
            if(map_last_output.count(arg) == 0)
Shucai Xiao's avatar
Shucai Xiao committed
172
            {
Shucai Xiao's avatar
Shucai Xiao committed
173
                MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input");
Shucai Xiao's avatar
Shucai Xiao committed
174
            }
Shucai Xiao's avatar
Shucai Xiao committed
175
176

            prog.replace_instruction(ins, map_last_output[arg]);
Shucai Xiao's avatar
Shucai Xiao committed
177
        }
Shucai Xiao's avatar
Shucai Xiao committed
178
179
180
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
181
std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
182
183
184
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
Shucai Xiao's avatar
Shucai Xiao committed
185
186
                                                   instruction_ref w,
                                                   instruction_ref r,
Shucai Xiao's avatar
Shucai Xiao committed
187
                                                   instruction_ref bias,
Shucai Xiao's avatar
Shucai Xiao committed
188
                                                   instruction_ref ih,
Shucai Xiao's avatar
Shucai Xiao committed
189
                                                   operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
190
{
Shucai Xiao's avatar
Shucai Xiao committed
191
192
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
193
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
Shucai Xiao's avatar
Shucai Xiao committed
194
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
195
196

    // squeeze and transpose r
Shucai Xiao's avatar
Shucai Xiao committed
197
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
Shucai Xiao's avatar
Shucai Xiao committed
198
199
200
201
202
203
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
Shucai Xiao's avatar
Shucai Xiao committed
204
    if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
205
    {
Shucai Xiao's avatar
Shucai Xiao committed
206
207
208
209
210
211
        long hs    = r->get_shape().lens()[2];
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
Shucai Xiao's avatar
Shucai Xiao committed
212
213
    }

Shucai Xiao's avatar
Shucai Xiao committed
214
215
    instruction_ref hidden_out = prog.end();
    instruction_ref last_out{};
Shucai Xiao's avatar
Shucai Xiao committed
216
217
    last_out            = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
    std::size_t seq_len = input->get_shape().lens()[0];
Shucai Xiao's avatar
Shucai Xiao committed
218
219
    for(std::size_t i = 0; i < seq_len; i++)
    {
Shucai Xiao's avatar
Shucai Xiao committed
220
        long seq_index = is_forward ? i : (seq_len - 1 - i);
Shucai Xiao's avatar
Shucai Xiao committed
221
222
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
Shucai Xiao's avatar
Shucai Xiao committed
223
224
225
226
        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
        instruction_ref ht;
Shucai Xiao's avatar
Shucai Xiao committed
227
228
        if(bias != prog.end())
        {
Shucai Xiao's avatar
Shucai Xiao committed
229
            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
Shucai Xiao's avatar
Shucai Xiao committed
230
231
232
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
233
            ht = xt_ht;
Shucai Xiao's avatar
Shucai Xiao committed
234
235
236
        }

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
237
        ht  = prog.insert_instruction(ins, actv_func, ht);
Shucai Xiao's avatar
Shucai Xiao committed
238
        sih = ht;
Shucai Xiao's avatar
Shucai Xiao committed
239

Shucai Xiao's avatar
Shucai Xiao committed
240
241
242
        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
Shucai Xiao's avatar
Shucai Xiao committed
243

Shucai Xiao's avatar
Shucai Xiao committed
244
245
246
        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
Shucai Xiao's avatar
Shucai Xiao committed
247
        if(i < seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
248
        {
Shucai Xiao's avatar
Shucai Xiao committed
249
250
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
251
252
253
254
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
255
256
257
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
258
259
260
261
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
262
            }
Shucai Xiao's avatar
Shucai Xiao committed
263
264
265
        }
    }

266
    return {hidden_out, last_out};
Shucai Xiao's avatar
Shucai Xiao committed
267
268
}

269
270
271
272
// Returns the activation functions to use when rewriting the rnn instruction
// ins, filling in defaults the same way parse_rnn does:
//   - none specified        -> tanh (two copies when bidirectional)
//   - bidirectional with one -> that function used for both directions
//   - otherwise              -> the operator's list unchanged
std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) const
{
    auto rnn_op        = any_cast<op::rnn>(ins->get_operator());
    const auto& funcs  = rnn_op.actv_funcs;
    const bool two_way = (rnn_op.direction == op::rnn::bidirectional);

    if(funcs.empty())
    {
        // default is tanh, one per direction
        if(two_way)
        {
            return {op::tanh{}, op::tanh{}};
        }
        return {op::tanh{}};
    }

    if(two_way && funcs.size() == 1)
    {
        // share the single activation between the forward and reverse passes
        return {funcs.at(0), funcs.at(0)};
    }

    return funcs;
}

Shucai Xiao's avatar
Shucai Xiao committed
305
306
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx