rewrite_rnn.cpp 11.4 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_rnn::apply(program& prog) const
{
Shucai Xiao's avatar
Shucai Xiao committed
13
    std::unordered_map<instruction_ref, instruction_ref> map_last_output;
Shucai Xiao's avatar
Shucai Xiao committed
14
15
    for(auto ins : iterator_for(prog))
    {
Shucai Xiao's avatar
Shucai Xiao committed
16
17
        // rewrite rnn operator
        if(ins->name() == "rnn")
Shucai Xiao's avatar
Shucai Xiao committed
18
        {
19
20
21
22
            // could be 3 to 6 inputs, but the parse_rnn function will
            // append undefined operators to make 6 arguments when parsing
            // an onnx file. Another case is user can have only 3 arguments
            // when writing their program.
Shucai Xiao's avatar
Shucai Xiao committed
23
24
25
26
27
28
            auto args = ins->inputs();

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[1]->get_shape().lens()[1];
            std::size_t batch_size  = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();
29
30
            migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
            std::vector<float> data(ih_shape.elements(), 0);
Shucai Xiao's avatar
Shucai Xiao committed
31
32
33
34

            auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
            op::rnn::rnn_direction_t dicrt = rnn_op.direction;
            if(dicrt == op::rnn::rnn_direction_t::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
35
            {
Shucai Xiao's avatar
Shucai Xiao committed
36
                // input weight matrix
Shucai Xiao's avatar
Shucai Xiao committed
37
38
                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);
Shucai Xiao's avatar
Shucai Xiao committed
39
40

                // hidden state weight matrix
Shucai Xiao's avatar
Shucai Xiao committed
41
42
                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
Shucai Xiao's avatar
Shucai Xiao committed
43
44
45
46

                // process bias
                instruction_ref bias_forward, bias_reverse;
                bias_forward = bias_reverse = prog.end();
47
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
48
49
50
51
52
53
54
55
                {
                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                }

                // process intial hidden state, it could be the 6th argument
                // or the 5th one (if the sequence len argument is ignored)
                instruction_ref ih_forward, ih_reverse;
56
                if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
57
                {
58
59
                    ih_forward  = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
                    ih_reverse  = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[5]);
Shucai Xiao's avatar
Shucai Xiao committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
                }
                else
                {
                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret_forward = rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            rnn_op.actv_funcs.at(0));
                auto ret_reverse = rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            rnn_op.actv_funcs.at(1));

Shucai Xiao's avatar
Shucai Xiao committed
86
87
                auto concat_output =
                    prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
88
                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
Shucai Xiao's avatar
Shucai Xiao committed
89

Shucai Xiao's avatar
Shucai Xiao committed
90
91
92
                // The following logic is to ensure the last instruction rewritten from
                // rnn operator is a concat instruction
                // sequence len is 1
Shucai Xiao's avatar
Shucai Xiao committed
93
                instruction_ref hidden_output{};
Shucai Xiao's avatar
Shucai Xiao committed
94
                if(ret_forward[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
95
                {
Shucai Xiao's avatar
Shucai Xiao committed
96
97
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
Shucai Xiao's avatar
Shucai Xiao committed
98
                }
Shucai Xiao's avatar
Shucai Xiao committed
99
                else
Shucai Xiao's avatar
Shucai Xiao committed
100
                {
Shucai Xiao's avatar
Shucai Xiao committed
101
102
103
104
                    ret_forward[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_forward[0], ret_forward[1]);
                    ret_reverse[0] =
                        prog.insert_instruction(ins, op::concat{0}, ret_reverse[1], ret_reverse[0]);
Shucai Xiao's avatar
Shucai Xiao committed
105
106
                    hidden_output = prog.replace_instruction(
                        ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
Shucai Xiao's avatar
Shucai Xiao committed
107
                }
Shucai Xiao's avatar
Shucai Xiao committed
108
                map_last_output[hidden_output] = last_output;
Shucai Xiao's avatar
Shucai Xiao committed
109
110
111
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
112
                bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward);
Shucai Xiao's avatar
Shucai Xiao committed
113
114
115
116
117
118
119
120
                // input weight matrix
                auto w = args[1];

                // hidden state weight matrix
                auto r = args[2];

                // process bias and initial hidden state
                instruction_ref bias = prog.end();
121
                if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
122
123
124
125
126
127
                {
                    bias = args[3];
                }

                // process intial hidden state
                instruction_ref ih;
128
                if(args.size() == 6  && args[5]->get_operator().name() != "undefined")
Shucai Xiao's avatar
Shucai Xiao committed
129
                {
130
                    ih = args[5];
Shucai Xiao's avatar
Shucai Xiao committed
131
132
133
134
135
136
                }
                else
                {
                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                }

Shucai Xiao's avatar
Shucai Xiao committed
137
138
                auto ret = rnn_cell(
                    is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
Shucai Xiao's avatar
Shucai Xiao committed
139
                auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
140

Shucai Xiao's avatar
Shucai Xiao committed
141
                // following logic is to ensure the last instruction is a
Shucai Xiao's avatar
Shucai Xiao committed
142
143
                // concat instruction
                // sequence len is 1
Shucai Xiao's avatar
Shucai Xiao committed
144
                instruction_ref hidden_output{};
Shucai Xiao's avatar
Shucai Xiao committed
145
                if(ret[0] == prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
146
                {
Shucai Xiao's avatar
Shucai Xiao committed
147
                    hidden_output = prog.replace_instruction(ins, op::concat{0}, ret[1]);
Shucai Xiao's avatar
Shucai Xiao committed
148
149
150
151
152
                }
                else
                {
                    auto concat_arg0 = is_forward ? ret[0] : ret[1];
                    auto concat_arg1 = is_forward ? ret[1] : ret[0];
Shucai Xiao's avatar
Shucai Xiao committed
153
154
                    hidden_output =
                        prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
Shucai Xiao's avatar
Shucai Xiao committed
155
                }
Shucai Xiao's avatar
Shucai Xiao committed
156
                map_last_output[hidden_output] = last_output;
Shucai Xiao's avatar
Shucai Xiao committed
157
158
159
            }
        }

Shucai Xiao's avatar
Shucai Xiao committed
160
        // rewrite the rnn_last_output operator that right after the rnn
Shucai Xiao's avatar
Shucai Xiao committed
161
162
163
        // operator. Intuitively, we can do a slice on the input to get
        // the last output, but it is already existed in the rnn operator,
        // so we can just use it as the output here
Shucai Xiao's avatar
Shucai Xiao committed
164
        if(ins->name() == "rnn_last_output")
Shucai Xiao's avatar
Shucai Xiao committed
165
        {
Shucai Xiao's avatar
Shucai Xiao committed
166
167
168
            auto inputs = ins->inputs();
            assert(inputs.size() == 1);
            auto arg = inputs[0];
Shucai Xiao's avatar
Shucai Xiao committed
169
            if(map_last_output.count(arg) == 0)
Shucai Xiao's avatar
Shucai Xiao committed
170
            {
Shucai Xiao's avatar
Shucai Xiao committed
171
                MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input");
Shucai Xiao's avatar
Shucai Xiao committed
172
            }
Shucai Xiao's avatar
Shucai Xiao committed
173
174

            prog.replace_instruction(ins, map_last_output[arg]);
Shucai Xiao's avatar
Shucai Xiao committed
175
        }
Shucai Xiao's avatar
Shucai Xiao committed
176
177
178
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
179
std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
180
181
182
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
Shucai Xiao's avatar
Shucai Xiao committed
183
184
                                                   instruction_ref w,
                                                   instruction_ref r,
Shucai Xiao's avatar
Shucai Xiao committed
185
                                                   instruction_ref bias,
Shucai Xiao's avatar
Shucai Xiao committed
186
                                                   instruction_ref ih,
Shucai Xiao's avatar
Shucai Xiao committed
187
                                                   operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
188
{
Shucai Xiao's avatar
Shucai Xiao committed
189
190
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
Shucai Xiao's avatar
Shucai Xiao committed
191
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
Shucai Xiao's avatar
Shucai Xiao committed
192
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);
Shucai Xiao's avatar
Shucai Xiao committed
193
194

    // squeeze and transpose r
Shucai Xiao's avatar
Shucai Xiao committed
195
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
Shucai Xiao's avatar
Shucai Xiao committed
196
197
198
199
200
201
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
Shucai Xiao's avatar
Shucai Xiao committed
202
    if(bias != prog.end())
Shucai Xiao's avatar
Shucai Xiao committed
203
    {
Shucai Xiao's avatar
Shucai Xiao committed
204
205
206
207
208
209
        long hs    = r->get_shape().lens()[2];
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
Shucai Xiao's avatar
Shucai Xiao committed
210
211
    }

Shucai Xiao's avatar
Shucai Xiao committed
212
    instruction_ref hidden_out = prog.end(), last_out;
Shucai Xiao's avatar
Shucai Xiao committed
213
    last_out                   = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
Shucai Xiao's avatar
Shucai Xiao committed
214
    std::size_t seq_len        = input->get_shape().lens()[0];
Shucai Xiao's avatar
Shucai Xiao committed
215
216
    for(std::size_t i = 0; i < seq_len; i++)
    {
Shucai Xiao's avatar
Shucai Xiao committed
217
        long seq_index = is_forward ? i : (seq_len - 1 - i);
Shucai Xiao's avatar
Shucai Xiao committed
218
219
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
Shucai Xiao's avatar
Shucai Xiao committed
220
221
222
223
        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
        instruction_ref ht;
Shucai Xiao's avatar
Shucai Xiao committed
224
225
        if(bias != prog.end())
        {
Shucai Xiao's avatar
Shucai Xiao committed
226
            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
Shucai Xiao's avatar
Shucai Xiao committed
227
228
229
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
230
            ht = xt_ht;
Shucai Xiao's avatar
Shucai Xiao committed
231
232
233
        }

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
234
        ht  = prog.insert_instruction(ins, actv_func, ht);
Shucai Xiao's avatar
Shucai Xiao committed
235
        sih = ht;
Shucai Xiao's avatar
Shucai Xiao committed
236

Shucai Xiao's avatar
Shucai Xiao committed
237
238
239
        // add the dimensions of sequence length (axis 0 for sequence length,
        // axis 1 for num_directions
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, ht);
Shucai Xiao's avatar
Shucai Xiao committed
240

Shucai Xiao's avatar
Shucai Xiao committed
241
242
243
        // concatenation for the last last_out is performed in the apply()
        // function to ensure the last instruction is concat, then we have
        // output inserted
Shucai Xiao's avatar
Shucai Xiao committed
244
        if(i < seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
245
        {
Shucai Xiao's avatar
Shucai Xiao committed
246
247
            if(is_forward)
            {
Shucai Xiao's avatar
Shucai Xiao committed
248
249
250
251
                hidden_out =
                    (seq_index == 0)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
252
253
254
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
255
256
257
258
                hidden_out =
                    (seq_index == seq_len - 1)
                        ? last_out
                        : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
259
            }
Shucai Xiao's avatar
Shucai Xiao committed
260
261
262
        }
    }

263
    return {hidden_out, last_out};
Shucai Xiao's avatar
Shucai Xiao committed
264
265
266
267
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx