"vscode:/vscode.git/clone" did not exist on "e8cb5106ff6a3fabd5f5b25d38e34ea7c74c31f0"
rewrite_rnn.cpp 12.3 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_rnn::apply(program& prog) const
{
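    // Rewrite each "rnn" operator into primitive instructions (slice, dot,
    // add, broadcast, activation, concat, squeeze/unsqueeze), and rewrite
    // each "rnn_last_output" operator to reuse the last output recorded in
    // map_last_output by the corresponding rnn rewrite.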
    std::unordered_map<instruction_ref, instruction_ref> map_last_output;
    for(auto ins : iterator_for(prog))
    {
        // rewrite rnn operator
        if(ins->name() == "rnn")
        {
            // There can be 3 to 6 inputs. When parsing an onnx file, the
            // parse_rnn function appends undefined operators so that there
            // are always 6 arguments, but a user-written program may pass
            // as few as 3.
            auto args = ins->inputs();

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[1]->get_shape().lens()[1];
            std::size_t batch_size  = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();
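
            // default initial hidden state: zeros with shape
            // {num_directions (1 per cell), batch_size, hidden_size}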
            migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
            std::vector<float> data(ih_shape.elements(), 0);

            auto actv_funcs                = compute_actv_funcs(ins);
            auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
            op::rnn::rnn_direction_t dicrt = rnn_op.direction;
            if(dicrt == op::rnn::bidirectional)
            {
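                // bidirectional: split every stacked argument along the
                // num_directions axis (axis 0) into forward and reverse
                // halves, run one rnn_cell per direction, and concatenate
                // the results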
                // input weight matrix
                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

                // hidden state weight matrix
                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

                // process bias
                instruction_ref bias_forward, bias_reverse;
                bias_forward = bias_reverse = prog.end();
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
                {
                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                }

                // process initial hidden state; it could be the 6th argument,
                // or the 5th one if the sequence length argument is omitted
                instruction_ref ih_forward, ih_reverse;
                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
                {
                    ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
                    ih_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
                }
                else
                {
                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret_forward = rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            actv_funcs.at(0));
                auto ret_reverse = rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            actv_funcs.at(1));

                auto concat_output =
                    prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);

                // The following logic ensures that the last instruction
                // rewritten from the rnn operator is a concat instruction.
                // ret_forward[0] is prog.end() when the sequence length is 1.
                instruction_ref hidden_output{};
                if(ret_forward[0] == prog.end())
                {
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
                }
                else
                {
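                    // Append the forward cell's final step to its accumulated
                    // outputs and prepend the reverse cell's final step (time
                    // step 0) to its outputs so both are ordered by sequence
                    // index, then concatenate the two directions along the
                    // num_directions axis (axis 1).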
                    ret_forward[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
                    ret_reverse[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
                }
                map_last_output[hidden_output] = last_output;
            }
            else
            {
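                // single direction (forward or reverse): the arguments are
                // already stacked with num_directions == 1, so use them
                // directly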
                bool is_forward = (dicrt == op::rnn::forward);
                // input weight matrix
                auto w = args[1];

                // hidden state weight matrix
                auto r = args[2];

                // process bias and initial hidden state
                instruction_ref bias = prog.end();
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
                {
                    bias = args[3];
                }

                // process initial hidden state
                instruction_ref ih;
                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
                {
                    ih = args[5];
                }
                else
                {
                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret =
                    rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);

                // The following logic ensures that the last instruction is a
                // concat instruction.
                // ret[0] is prog.end() when the sequence length is 1.
                instruction_ref hidden_output{};
                if(ret[0] == prog.end())
                {
                    hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
                }
                else
                {
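                    // place the final time step after the accumulated outputs
                    // for the forward direction and before them for the
                    // reverse direction, keeping ascending sequence order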
                    auto concat_arg0 = is_forward ? ret[0] : ret[1];
                    auto concat_arg1 = is_forward ? ret[1] : ret[0];
                    hidden_output =
                        prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
                }
                map_last_output[hidden_output] = last_output;
            }
        }

        // rewrite the rnn_last_output operator that comes right after the rnn
        // operator. Intuitively, we could slice the input to get the last
        // output, but it already exists in the rewritten rnn operator, so we
        // can reuse it as the output here
        if(ins->name() == "rnn_last_output")
        {
            auto inputs = ins->inputs();
            assert(inputs.size() == 1);
            auto arg = inputs[0];
            if(map_last_output.count(arg) == 0)
            {
                MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input");
            }

            prog.replace_instruction(ins, map_last_output[arg]);
        }
    }
}

std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
                                                   instruction_ref w,
                                                   instruction_ref r,
                                                   instruction_ref bias,
                                                   instruction_ref ih,
                                                   operation& actv_func) const
{
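    // Returns {hidden_out, last_out}. hidden_out concatenates the outputs of
    // all but the last processed time step (it stays prog.end() when the
    // sequence length is 1); last_out is the output of the last processed
    // time step, unsqueezed to {1, 1, batch_size, hidden_size}.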
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);

    // squeeze and transpose r
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
    if(bias != prog.end())
    {
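        // the bias input stacks Wb and Rb into one tensor of length
        // 2 * hidden_size; split it, add the two halves, and broadcast
        // the sum to the shape of the hidden state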
        long hs    = r->get_shape().lens()[2];
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
    }

    instruction_ref hidden_out = prog.end(), last_out;
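    // seed last_out with the (unsqueezed) initial hidden state; it is
    // overwritten on every time step, so after the loop it holds the output
    // of the last processed step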
    last_out                   = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
    std::size_t seq_len        = input->get_shape().lens()[0];
    for(std::size_t i = 0; i < seq_len; i++)
    {
        long seq_index = is_forward ? i : (seq_len - 1 - i);
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
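        // Ht = actv_func(Xt * W^T + H(t-1) * R^T + Wb + Rb); the bias term
        // is included only when a bias input was provided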
        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
        instruction_ref ht;
        if(bias != prog.end())
        {
            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
        }
        else
        {
            ht = xt_ht;
        }

        // apply activation function
        ht  = prog.insert_instruction(ins, actv_func, ht);
        sih = ht;

        // add back the sequence length and direction dimensions (axis 0 for
        // sequence length, axis 1 for num_directions)
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);

        // the concatenation with the final last_out is performed in the
        // apply() function, so that the instruction replacing the rnn
        // operator is a concat
        if(i < seq_len - 1)
        {
            if(is_forward)
            {
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
            }
            else
            {
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
            }
        }
    }

    return {hidden_out, last_out};
}

std::vector<operation> rewrite_rnn::compute_actv_funcs(instruction_ref ins) const
{
    auto rnn_op = any_cast<op::rnn>(ins->get_operator());
    // before rewriting the rnn operator, we need to ensure there are
    // 2 activation functions. If fewer than 2 are given, use the same
    // algorithm as parse_rnn to supply the missing ones.
    if(rnn_op.direction == op::rnn::bidirectional)
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {op::tanh{}, op::tanh{}};
        }
        else if(rnn_op.actv_funcs.size() == 1)
        {
            return {rnn_op.actv_funcs.at(0), rnn_op.actv_funcs.at(0)};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
    else
    {
        if(rnn_op.actv_funcs.empty())
        {
            // default is tanh
            return {op::tanh{}};
        }
        else
        {
            return rnn_op.actv_funcs;
        }
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx