rewrite_rnn.cpp 11.5 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_rnn::apply(program& prog) const
{
Shucai Xiao's avatar
Shucai Xiao committed
13
    std::unordered_map<instruction_ref, instruction_ref> map_last_output;
Shucai Xiao's avatar
Shucai Xiao committed
14
15
    for(auto ins : iterator_for(prog))
    {
Shucai Xiao's avatar
Shucai Xiao committed
16
17
        // rewrite rnn operator
        if(ins->name() == "rnn")
Shucai Xiao's avatar
Shucai Xiao committed
18
        {
Shucai Xiao's avatar
Shucai Xiao committed
19
20
21
            // could be 3 to 6 inputs, but the 5th input is undefined in
            // pytorch exported onnx, and it is ignored by protobuf. So
            // for input arguments 5 and 6, we need to check the shape,
Shucai Xiao's avatar
Shucai Xiao committed
22
23
24
25
26
27
28
            // then based on the shape to judge the specific input info
            auto args = ins->inputs();

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[1]->get_shape().lens()[1];
            std::size_t batch_size  = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();
29
            migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
30
            std::vector<float> data(ih_shape.elements(), 0);
Shucai Xiao's avatar
Shucai Xiao committed
31
32
33
34

            auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
            op::rnn::rnn_direction_t dicrt = rnn_op.direction;
            if(dicrt == op::rnn::rnn_direction_t::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
35
            {
Shucai Xiao's avatar
Shucai Xiao committed
36
                // input weight matrix
Shucai Xiao's avatar
Shucai Xiao committed
37
38
                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
Shucai Xiao's avatar
Shucai Xiao committed
39
40

                // hidden state weight matrix
Shucai Xiao's avatar
Shucai Xiao committed
41
42
                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
Shucai Xiao's avatar
Shucai Xiao committed
43
44
45
46
47
48
49
50
51
52
53
54
55

                // process bias
                instruction_ref bias_forward, bias_reverse;
                bias_forward = bias_reverse = prog.end();
                if(args.size() >= 4)
                {
                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                }

                // process intial hidden state, it could be the 6th argument
                // or the 5th one (if the sequence len argument is ignored)
                instruction_ref ih_forward, ih_reverse;
Shucai Xiao's avatar
Shucai Xiao committed
56
57
                if(args.size() == 6 ||
                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
Shucai Xiao's avatar
Shucai Xiao committed
58
59
                {
                    auto arg_ih = (args.size() == 6) ? args[5] : args[4];
Shucai Xiao's avatar
Shucai Xiao committed
60
61
                    ih_forward  = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
                    ih_reverse  = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
Shucai Xiao's avatar
Shucai Xiao committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
                }
                else
                {
                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret_forward = rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            rnn_op.actv_funcs.at(0));
                auto ret_reverse = rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            rnn_op.actv_funcs.at(1));

Shucai Xiao's avatar
Shucai Xiao committed
88
89
                auto concat_output =
                    prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
90
                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
Shucai Xiao's avatar
Shucai Xiao committed
91

Shucai Xiao's avatar
Shucai Xiao committed
92
93
94
                // The following logic is to ensure the last instruction rewritten from
                // rnn operator is a concat instruction
                // sequence len is 1
Shucai Xiao's avatar
Shucai Xiao committed
95
                instruction_ref hidden_output{};
Shucai Xiao's avatar
Shucai Xiao committed
96
                if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
97
                {
Shucai Xiao's avatar
Shucai Xiao committed
98
99
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
100
                }
Shucai Xiao's avatar
Shucai Xiao committed
101
                else
Shucai Xiao's avatar
Shucai Xiao committed
102
                {
Shucai Xiao's avatar
Shucai Xiao committed
103
104
105
106
                    ret_forward[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
                    ret_reverse[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
Shucai Xiao's avatar
Shucai Xiao committed
107
108
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
Shucai Xiao's avatar
Shucai Xiao committed
109
                }
Shucai Xiao's avatar
Shucai Xiao committed
110
                map_last_output[hidden_output] = last_output;
Shucai Xiao's avatar
Shucai Xiao committed
111
112
113
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
114
                bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward);
Shucai Xiao's avatar
Shucai Xiao committed
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
                // input weight matrix
                auto w = args[1];

                // hidden state weight matrix
                auto r = args[2];

                // process bias and initial hidden state
                instruction_ref bias = prog.end();
                if(args.size() >= 4)
                {
                    bias = args[3];
                }

                // process intial hidden state
                instruction_ref ih;
Shucai Xiao's avatar
Shucai Xiao committed
130
131
                if(args.size() == 6 ||
                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
Shucai Xiao's avatar
Shucai Xiao committed
132
133
134
135
136
137
138
139
                {
                    ih = (args.size() == 6) ? args[5] : args[4];
                }
                else
                {
                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                }

Shucai Xiao's avatar
Shucai Xiao committed
140
141
                auto ret = rnn_cell(
                    is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
Shucai Xiao's avatar
Shucai Xiao committed
142
                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
143

Shucai Xiao's avatar
Shucai Xiao committed
144
                // following logic is to ensure the last instruction is a
Shucai Xiao's avatar
Shucai Xiao committed
145
146
                // concat instruction
                // sequence len is 1
Shucai Xiao's avatar
Shucai Xiao committed
147
                instruction_ref hidden_output{};
Shucai Xiao's avatar
Shucai Xiao committed
148
                if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
149
                {
Shucai Xiao's avatar
Shucai Xiao committed
150
                    hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
151
152
153
154
155
                }
                else
                {
                    auto concat_arg0 = is_forward ? ret[0] : ret[1];
                    auto concat_arg1 = is_forward ? ret[1] : ret[0];
Shucai Xiao's avatar
Shucai Xiao committed
156
157
                    hidden_output =
                        prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
158
                }
Shucai Xiao's avatar
Shucai Xiao committed
159
                map_last_output[hidden_output] = last_output;
Shucai Xiao's avatar
Shucai Xiao committed
160
161
162
            }
        }

Shucai Xiao's avatar
Shucai Xiao committed
163
        // rewrite the rnn_last_output operator that right after the rnn
164
        // operator. Intuitively, we can do a slice on its input to get
Shucai Xiao's avatar
Shucai Xiao committed
165
166
        // the last output, but it is already existed in the rnn operator,
        // so we can just use it as the output here
Shucai Xiao's avatar
Shucai Xiao committed
167
        if(ins->name() == "rnn_last_output")
Shucai Xiao's avatar
Shucai Xiao committed
168
        {
Shucai Xiao's avatar
Shucai Xiao committed
169
170
171
            auto inputs = ins->inputs();
            assert(inputs.size() == 1);
            auto arg = inputs[0];
Shucai Xiao's avatar
Shucai Xiao committed
172
            if(map_last_output.count(arg) == 0)
Shucai Xiao's avatar
Shucai Xiao committed
173
            {
Shucai Xiao's avatar
Shucai Xiao committed
174
                MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input");
Shucai Xiao's avatar
Shucai Xiao committed
175
            }
Shucai Xiao's avatar
Shucai Xiao committed
176
177

            prog.replace_instruction(ins, map_last_output[arg]);
Shucai Xiao's avatar
Shucai Xiao committed
178
        }
Shucai Xiao's avatar
Shucai Xiao committed
179
180
181
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
182
std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
183
184
185
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
Shucai Xiao's avatar
Shucai Xiao committed
186
187
                                                   instruction_ref w,
                                                   instruction_ref r,
Shucai Xiao's avatar
Shucai Xiao committed
188
                                                   instruction_ref bias,
Shucai Xiao's avatar
Shucai Xiao committed
189
                                                   instruction_ref ih,
Shucai Xiao's avatar
Shucai Xiao committed
190
                                                   operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
191
{
Shucai Xiao's avatar
Shucai Xiao committed
192
193
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
194
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
Shucai Xiao's avatar
Shucai Xiao committed
195
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
196
197

    // squeeze and transpose r
Shucai Xiao's avatar
Shucai Xiao committed
198
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
Shucai Xiao's avatar
Shucai Xiao committed
199
200
201
202
203
204
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
Shucai Xiao's avatar
Shucai Xiao committed
205
    if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
206
    {
Shucai Xiao's avatar
Shucai Xiao committed
207
208
209
210
211
212
        long hs    = r->get_shape().lens()[2];
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
Shucai Xiao's avatar
Shucai Xiao committed
213
214
    }

Shucai Xiao's avatar
Shucai Xiao committed
215
    instruction_ref hidden_out = prog.end(), last_out;
Shucai Xiao's avatar
Shucai Xiao committed
216
    last_out                   = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
Shucai Xiao's avatar
Shucai Xiao committed
217
    std::size_t seq_len        = input->get_shape().lens()[0];
Shucai Xiao's avatar
Shucai Xiao committed
218
219
    for(std::size_t i = 0; i < seq_len; i++)
    {
Shucai Xiao's avatar
Shucai Xiao committed
220
        long seq_index = is_forward ? i : (seq_len - 1 - i);
Shucai Xiao's avatar
Shucai Xiao committed
221
222
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
Shucai Xiao's avatar
Shucai Xiao committed
223
224
225
226
        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
        instruction_ref ht;
Shucai Xiao's avatar
Shucai Xiao committed
227
228
        if(bias != prog.end())
        {
Shucai Xiao's avatar
Shucai Xiao committed
229
            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
Shucai Xiao's avatar
Shucai Xiao committed
230
231
232
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
233
            ht = xt_ht;
Shucai Xiao's avatar
Shucai Xiao committed
234
235
236
        }

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
237
        ht  = prog.insert_instruction(ins, actv_func, ht);
Shucai Xiao's avatar
Shucai Xiao committed
238
        sih = ht;
Shucai Xiao's avatar
Shucai Xiao committed
239

Shucai Xiao's avatar
Shucai Xiao committed
240
241
242
        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
Shucai Xiao's avatar
Shucai Xiao committed
243

Shucai Xiao's avatar
Shucai Xiao committed
244
245
246
        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
Shucai Xiao's avatar
Shucai Xiao committed
247
        if(i < seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
248
        {
Shucai Xiao's avatar
Shucai Xiao committed
249
250
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
251
252
253
254
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
255
256
257
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
258
259
260
261
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
262
            }
Shucai Xiao's avatar
Shucai Xiao committed
263
264
265
        }
    }

266
    return {hidden_out, last_out};
Shucai Xiao's avatar
Shucai Xiao committed
267
268
269
270
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx