#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

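// rewrite_rnn lowers the rnn operator (and any rnn_last_output consumers)
// into elementary instructions: slice, squeeze, dot, add, the activation
// operator, unsqueeze, and concat. It handles forward, reverse, and
// bidirectional variants.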
void rewrite_rnn::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        // rewrite the rnn operator
        if(ins->name() == "rnn")
        {
            // The rnn operator can have 3 to 6 inputs. When parsing an onnx
            // file, the parse_rnn function appends undefined operators to pad
            // the argument list to 6. A user writing a program directly may
            // pass only 3 arguments.
            auto args = ins->inputs();

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[1]->get_shape().lens()[1];
            std::size_t batch_size  = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();
            migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
            std::vector<float> data(ih_shape.elements(), 0);

            auto actv_funcs                = compute_actv_funcs(ins);
            auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
            op::rnn::rnn_direction_t dicrt = rnn_op.direction;
            instruction_ref last_output{};
            if(dicrt == op::rnn::bidirectional)
            {
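                // in the bidirectional case the weight, recurrence weight,
                // bias, and initial hidden state inputs carry a leading
                // num_directions dimension of size 2; slicing along axis 0
                // separates the forward (index 0) and reverse (index 1) parts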
                // input weight matrix
                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

                // hidden state weight matrix
                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

                // process bias
                instruction_ref bias_forward = prog.end();
                instruction_ref bias_reverse = prog.end();
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
                {
                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                }

                // process the initial hidden state; it could be the 6th argument,
                // or the 5th one (if the sequence length argument is ignored)
                instruction_ref ih_forward{};
                instruction_ref ih_reverse{};
                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
                {
                    ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
                    ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
                }
                else
                {
                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                }

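                // rnn_cell returns a pair {accumulated step outputs, last step
                // output} for a single direction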
                auto ret_forward = rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            actv_funcs.at(0));

                auto ret_reverse = rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            actv_funcs.at(1));

                auto concat_output =
                    prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
                last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

                // The following logic ensures that the last instruction rewritten
                // from the rnn operator is a concat instruction. The first branch
                // handles a sequence length of 1.
                instruction_ref hidden_output{};
                if(ret_forward[0] == prog.end())
                {
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
                }
                else
                {
                    ret_forward[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
                    ret_reverse[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
                }
            }
            else
            {
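                // single direction (forward or reverse): num_directions is 1,
                // so the weights, bias, and initial hidden state are used as-is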
                bool is_forward = (dicrt == op::rnn::forward);
                // input weight matrix
                auto w = args[1];

                // hidden state weight matrix
                auto r = args[2];

                // process bias and initial hidden state
                instruction_ref bias = prog.end();
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
                {
                    bias = args[3];
                }

                // process initial hidden state
                instruction_ref ih;
                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
                {
                    ih = args[5];
                }
                else
                {
                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret =
                    rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
                last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

                // The following logic ensures the last instruction is a
                // concat instruction. The first branch handles a sequence
                // length of 1.
                instruction_ref hidden_output{};
                if(ret[0] == prog.end())
                {
                    hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
                }
                else
                {
                    auto concat_arg0 = is_forward ? ret[0] : ret[1];
                    auto concat_arg1 = is_forward ? ret[1] : ret[0];
                    hidden_output =
                        prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
                }
            }

            // search the outputs of ins for rnn_last_output operators; the
            // while loop handles the case of multiple rnn_last_output operators
            auto last_output_it = ins->outputs().begin();
            while(last_output_it != ins->outputs().end())
            {
                last_output_it = std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
                    return i->name() == "rnn_last_output";
                });

                if(last_output_it != ins->outputs().end())
                {
                    prog.replace_instruction(*last_output_it, last_output);
                    last_output_it++;
                }
            }
        }
    }
}

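// rnn_cell unrolls a single-direction RNN over the sequence. Each time step
// computes Ht = f(Xt * W^T + H(t-1) * R^T + Wb + Rb), where f is actv_func,
// matching the instructions inserted in the loop below.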
std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
                                                   instruction_ref w,
                                                   instruction_ref r,
                                                   instruction_ref bias,
                                                   instruction_ref ih,
                                                   operation& actv_func) const
{
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);

    // squeeze and transpose r
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
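    // the bias input packs Wb and Rb back to back (2 * hidden_size values);
    // slice out the two halves, add them, and broadcast the sum to the
    // shape of the hidden state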
    if(bias != prog.end())
    {
        long hs    = r->get_shape().lens()[2];
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
    }

    instruction_ref hidden_out = prog.end();
    instruction_ref last_out{};
    last_out            = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
    std::size_t seq_len = input->get_shape().lens()[0];
    for(std::size_t i = 0; i < seq_len; i++)
    {
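        // one time step: slice Xt from the input, compute Xt * W^T + H(t-1) * R^T,
        // add the (optional) bias, then apply the activation function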
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
        instruction_ref ht;
        if(bias != prog.end())
        {
            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
        }
        else
        {
            ht = xt_ht;
        }

        // apply activation function
        ht  = prog.insert_instruction(ins, actv_func, ht);
        sih = ht;

        // add back the sequence length and num_directions dimensions
        // (axis 0 for sequence length, axis 1 for num_directions)
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);

        // concatenation of the final last_out is performed in the apply()
        // function to ensure the last instruction is a concat before the
        // output is inserted
        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
            }
            else
            {
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
            }
        }
    }

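    // hidden_out holds the concatenation of all but the final step (the final
    // concat is done in apply()); last_out is the final step's output with the
    // sequence and num_directions axes restored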
    return {hidden_out, last_out};
}

std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) const
{
    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
    // before rewriting the rnn operator, we need to ensure there are
    // 2 actv functions. If there are fewer than 2, use the same
    // algorithm as parse_rnn to create 2 actv functions.
    if(rnn_op.direction == op::rnn::bidirectional)
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {op::tanh{}, op::tanh{}};
        }
        else if(rnn_op.actv_funcs.size() == 1)
        {
            return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
    else
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {op::tanh{}};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx