#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_rnn::apply(program& prog) const
{
Shucai Xiao's avatar
Shucai Xiao committed
13
    std::unordered_map<instruction_ref, instruction_ref> map_last_output;
Shucai Xiao's avatar
Shucai Xiao committed
14
15
    for(auto ins : iterator_for(prog))
    {
Shucai Xiao's avatar
Shucai Xiao committed
16
17
        // rewrite rnn operator
        if(ins->name() == "rnn")
Shucai Xiao's avatar
Shucai Xiao committed
18
        {
19
20
21
22
            // could be 3 to 6 inputs, but the parse_rnn function will
            // append undefined operators to make 6 arguments when parsing
            // an onnx file. Another case is user can have only 3 arguments
            // when writing their program.
Shucai Xiao's avatar
Shucai Xiao committed
23
24
25
26
27
28
            auto args = ins->inputs();

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[1]->get_shape().lens()[1];
            std::size_t batch_size  = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();
29
30
            migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
            std::vector<float> data(ih_shape.elements(), 0);
Shucai Xiao's avatar
Shucai Xiao committed
31

32
            auto actv_funcs = compute_actv_funcs(ins);
Shucai Xiao's avatar
Shucai Xiao committed
33
34
            auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
            op::rnn::rnn_direction_t dicrt = rnn_op.direction;
35
            if(dicrt == op::rnn::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
36
            {
Shucai Xiao's avatar
Shucai Xiao committed
37
                // input weight matrix
Shucai Xiao's avatar
Shucai Xiao committed
38
39
                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
Shucai Xiao's avatar
Shucai Xiao committed
40
41

                // hidden state weight matrix
Shucai Xiao's avatar
Shucai Xiao committed
42
43
                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
Shucai Xiao's avatar
Shucai Xiao committed
44
45
46
47

                // process bias
                instruction_ref bias_forward, bias_reverse;
                bias_forward = bias_reverse = prog.end();
48
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
49
50
51
52
53
54
55
56
                {
                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                }

                // process intial hidden state, it could be the 6th argument
                // or the 5th one (if the sequence len argument is ignored)
                instruction_ref ih_forward, ih_reverse;
57
                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
58
                {
Shucai Xiao's avatar
Shucai Xiao committed
59
60
                    ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
                    ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
Shucai Xiao's avatar
Shucai Xiao committed
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
                }
                else
                {
                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret_forward = rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
76
                                            actv_funcs.at(0));
Shucai Xiao's avatar
Shucai Xiao committed
77
78
79
80
81
82
83
84
                auto ret_reverse = rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
85
                                            actv_funcs.at(1));
Shucai Xiao's avatar
Shucai Xiao committed
86

Shucai Xiao's avatar
Shucai Xiao committed
87
88
                auto concat_output =
                    prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
89
                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
Shucai Xiao's avatar
Shucai Xiao committed
90

Shucai Xiao's avatar
Shucai Xiao committed
91
92
93
                // The following logic is to ensure the last instruction rewritten from
                // rnn operator is a concat instruction
                // sequence len is 1
Shucai Xiao's avatar
Shucai Xiao committed
94
                instruction_ref hidden_output{};
Shucai Xiao's avatar
Shucai Xiao committed
95
                if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
96
                {
Shucai Xiao's avatar
Shucai Xiao committed
97
98
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
99
                }
Shucai Xiao's avatar
Shucai Xiao committed
100
                else
Shucai Xiao's avatar
Shucai Xiao committed
101
                {
Shucai Xiao's avatar
Shucai Xiao committed
102
103
104
105
                    ret_forward[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
                    ret_reverse[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
Shucai Xiao's avatar
Shucai Xiao committed
106
107
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
Shucai Xiao's avatar
Shucai Xiao committed
108
                }
Shucai Xiao's avatar
Shucai Xiao committed
109
                map_last_output[hidden_output] = last_output;
Shucai Xiao's avatar
Shucai Xiao committed
110
111
112
            }
            else
            {
113
                bool is_forward = (dicrt == op::rnn::forward);
Shucai Xiao's avatar
Shucai Xiao committed
114
115
116
117
118
119
120
121
                // input weight matrix
                auto w = args[1];

                // hidden state weight matrix
                auto r = args[2];

                // process bias and initial hidden state
                instruction_ref bias = prog.end();
122
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
123
124
125
126
127
128
                {
                    bias = args[3];
                }

                // process intial hidden state
                instruction_ref ih;
Shucai Xiao's avatar
Shucai Xiao committed
129
                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
130
                {
131
                    ih = args[5];
Shucai Xiao's avatar
Shucai Xiao committed
132
133
134
135
136
137
                }
                else
                {
                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                }

Shucai Xiao's avatar
Shucai Xiao committed
138
                auto ret = rnn_cell(
139
                    is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
Shucai Xiao's avatar
Shucai Xiao committed
140
                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
141

Shucai Xiao's avatar
Shucai Xiao committed
142
                // following logic is to ensure the last instruction is a
Shucai Xiao's avatar
Shucai Xiao committed
143
144
                // concat instruction
                // sequence len is 1
Shucai Xiao's avatar
Shucai Xiao committed
145
                instruction_ref hidden_output{};
Shucai Xiao's avatar
Shucai Xiao committed
146
                if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
147
                {
Shucai Xiao's avatar
Shucai Xiao committed
148
                    hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
149
150
151
152
153
                }
                else
                {
                    auto concat_arg0 = is_forward ? ret[0] : ret[1];
                    auto concat_arg1 = is_forward ? ret[1] : ret[0];
Shucai Xiao's avatar
Shucai Xiao committed
154
155
                    hidden_output =
                        prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
156
                }
Shucai Xiao's avatar
Shucai Xiao committed
157
                map_last_output[hidden_output] = last_output;
Shucai Xiao's avatar
Shucai Xiao committed
158
159
160
            }
        }

Shucai Xiao's avatar
Shucai Xiao committed
161
        // rewrite the rnn_last_output operator that right after the rnn
Shucai Xiao's avatar
Shucai Xiao committed
162
163
164
        // operator. Intuitively, we can do a slice on the input to get
        // the last output, but it is already existed in the rnn operator,
        // so we can just use it as the output here
Shucai Xiao's avatar
Shucai Xiao committed
165
        if(ins->name() == "rnn_last_output")
Shucai Xiao's avatar
Shucai Xiao committed
166
        {
Shucai Xiao's avatar
Shucai Xiao committed
167
168
169
            auto inputs = ins->inputs();
            assert(inputs.size() == 1);
            auto arg = inputs[0];
Shucai Xiao's avatar
Shucai Xiao committed
170
            if(map_last_output.count(arg) == 0)
Shucai Xiao's avatar
Shucai Xiao committed
171
            {
Shucai Xiao's avatar
Shucai Xiao committed
172
                MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input");
Shucai Xiao's avatar
Shucai Xiao committed
173
            }
Shucai Xiao's avatar
Shucai Xiao committed
174
175

            prog.replace_instruction(ins, map_last_output[arg]);
Shucai Xiao's avatar
Shucai Xiao committed
176
        }
Shucai Xiao's avatar
Shucai Xiao committed
177
178
179
    }
}

std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
181
182
183
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
Shucai Xiao's avatar
Shucai Xiao committed
184
185
                                                   instruction_ref w,
                                                   instruction_ref r,
Shucai Xiao's avatar
Shucai Xiao committed
186
                                                   instruction_ref bias,
Shucai Xiao's avatar
Shucai Xiao committed
187
                                                   instruction_ref ih,
Shucai Xiao's avatar
Shucai Xiao committed
188
                                                   operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
189
{
Shucai Xiao's avatar
Shucai Xiao committed
190
191
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
192
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
Shucai Xiao's avatar
Shucai Xiao committed
193
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
194
195

    // squeeze and transpose r
Shucai Xiao's avatar
Shucai Xiao committed
196
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
Shucai Xiao's avatar
Shucai Xiao committed
197
198
199
200
201
202
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
Shucai Xiao's avatar
Shucai Xiao committed
203
    if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
204
    {
Shucai Xiao's avatar
Shucai Xiao committed
205
206
207
208
209
210
        long hs    = r->get_shape().lens()[2];
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
Shucai Xiao's avatar
Shucai Xiao committed
211
212
    }

Shucai Xiao's avatar
Shucai Xiao committed
213
    instruction_ref hidden_out = prog.end(), last_out;
Shucai Xiao's avatar
Shucai Xiao committed
214
    last_out                   = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
Shucai Xiao's avatar
Shucai Xiao committed
215
    std::size_t seq_len        = input->get_shape().lens()[0];
Shucai Xiao's avatar
Shucai Xiao committed
216
217
    for(std::size_t i = 0; i < seq_len; i++)
    {
Shucai Xiao's avatar
Shucai Xiao committed
218
        long seq_index = is_forward ? i : (seq_len - 1 - i);
Shucai Xiao's avatar
Shucai Xiao committed
219
220
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
Shucai Xiao's avatar
Shucai Xiao committed
221
222
223
224
        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
        instruction_ref ht;
Shucai Xiao's avatar
Shucai Xiao committed
225
226
        if(bias != prog.end())
        {
Shucai Xiao's avatar
Shucai Xiao committed
227
            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
Shucai Xiao's avatar
Shucai Xiao committed
228
229
230
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
231
            ht = xt_ht;
Shucai Xiao's avatar
Shucai Xiao committed
232
233
234
        }

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
235
        ht  = prog.insert_instruction(ins, actv_func, ht);
Shucai Xiao's avatar
Shucai Xiao committed
236
        sih = ht;
Shucai Xiao's avatar
Shucai Xiao committed
237

Shucai Xiao's avatar
Shucai Xiao committed
238
239
240
        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
Shucai Xiao's avatar
Shucai Xiao committed
241

Shucai Xiao's avatar
Shucai Xiao committed
242
243
244
        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
Shucai Xiao's avatar
Shucai Xiao committed
245
        if(i < seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
246
        {
Shucai Xiao's avatar
Shucai Xiao committed
247
248
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
249
250
251
252
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
253
254
255
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
256
257
258
259
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
260
            }
Shucai Xiao's avatar
Shucai Xiao committed
261
262
263
        }
    }

264
    return {hidden_out, last_out};
Shucai Xiao's avatar
Shucai Xiao committed
265
266
}


// Returns the activation functions to use when rewriting the rnn instruction
// `ins`, filling in defaults for the ones the operator does not specify.
//
// Bidirectional: none given -> {tanh, tanh}; one given -> that one is used for
// both directions; otherwise the operator's list is returned unchanged.
// Other directions: none given -> {tanh}; otherwise the operator's list is
// returned unchanged.
std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) const
{
    // before rewrite the rnn operator, need to ensure
    // we have 2 actv funcs. If less than 2, use the
    // algorithm in parse_rnn to make 2 actv functions
    const auto& rnn_op     = any_cast<op::rnn>(ins->get_operator());
    const auto& actv_funcs = rnn_op.actv_funcs;

    if(rnn_op.direction == op::rnn::bidirectional)
    {
        if(actv_funcs.empty())
        {
            // default is tanh
            return {op::tanh{}, op::tanh{}};
        }
        if(actv_funcs.size() == 1)
        {
            // a single function is shared by both directions
            return {actv_funcs.at(0), actv_funcs.at(0)};
        }
        return actv_funcs;
    }

    if(actv_funcs.empty())
    {
        // default is tanh
        return {op::tanh{}};
    }
    return actv_funcs;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx