rewrite_rnn.cpp

#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

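// Rewrite each onnx-style rnn operator into primitive operations (slice,
// squeeze, transpose, dot, add, concat, unsqueeze) inserted in its place.
// A bidirectional rnn is expanded as two independent single-direction cells
// whose outputs are concatenated along the num_directions axis.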
void rewrite_rnn::apply(program& prog) const
{
    instruction_ref last_output = prog.end();
    for(auto ins : iterator_for(prog))
    {
        // rewrite rnn operator
        if(ins->name() == "rnn")
        {
            // An rnn operator can have 3 to 6 inputs, but the 5th input
            // (sequence lengths) is undefined in onnx exported from pytorch
            // and is dropped by protobuf. So for input arguments 5 and 6 we
            // check the shape to determine which input each one actually is.
            auto args = ins->inputs();

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[1]->get_shape().lens()[1];
            std::size_t batch_size  = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();
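            // default initial hidden state: a zero tensor of shape
            // {batch_size, hidden_size}, used when no initial state is given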
            migraphx::shape ih_shape{type, {batch_size, hidden_size}};
            std::vector<char> data(ih_shape.bytes(), 0);

            auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
            op::rnn::rnn_direction_t dirct = rnn_op.direction;
            if(dirct == op::rnn::rnn_direction_t::bidirectional)
            {
                // input weight matrix
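                // axis 0 of the weights is num_directions: index 0 holds the
                // forward weights and index 1 the reverse weights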
                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

                // hidden state weight matrix
                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

                // process bias
                instruction_ref bias_forward, bias_reverse;
                bias_forward = bias_reverse = prog.end();
                if(args.size() >= 4)
                {
                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                }

                // process initial hidden state; it could be the 6th argument
                // or the 5th one (if the sequence length argument is omitted)
                instruction_ref ih_forward, ih_reverse;
                if(args.size() == 6 ||
                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
                {
                    auto arg_ih = (args.size() == 6) ? args[5] : args[4];
                    ih_forward  = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
                    ih_reverse  = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
                }
                else
                {
                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                    ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
                }

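                // expand each direction independently; rnn_cell returns the
                // hidden states of all time steps and the last hidden state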
                auto ret_forward = rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            rnn_op.actv_funcs.at(0));
                auto ret_reverse = rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            rnn_op.actv_funcs.at(1));

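                // record the last hidden states of both directions for the
                // rnn_last_output rewrite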
                last_output =
                    prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);

                // add the num_directions dimension
                ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
                ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);

                // concatenate the forward and reverse outputs
                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
            }
            else
            {
                bool is_forward = (dirct == op::rnn::rnn_direction_t::forward);
                // input weight matrix
                auto w = args[1];

                // hidden state weight matrix
                auto r = args[2];

                // process bias and initial hidden state
                instruction_ref bias = prog.end();
                if(args.size() >= 4)
                {
                    bias = args[3];
                }

                // process initial hidden state
                instruction_ref ih;
                if(args.size() == 6 ||
                   (args.size() == 5 && args[4]->get_shape().lens().size() == 3))
                {
                    ih = (args.size() == 6) ? args[5] : args[4];
                }
                else
                {
                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret = rnn_cell(
                    is_forward, prog, ins, args[0], w, r, bias, ih, rnn_op.actv_funcs.at(0));
                last_output = ret[1];

                // add the num_directions dimension
                prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
            }
        }

        // rewrite the rnn_last_output operator that comes right after the
        // rnn operator. Intuitively, we could slice the rnn output to get
        // the last output, but that value already exists from rewriting the
        // rnn operator, so we can just reuse it as the output here
        // if (ins->name() == "rnn_last_output")
        //{
        //    // if rnn operator is executed, the last_output != prog.end()
        //    if (last_output != prog.end())
        //    {
        //        prog.replace_instruction(ins, op::identity{}, last_output);
        //        last_output = prog.end();
        //    }
        //}
    }
}

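// Expand one direction of an rnn into elementary operations inserted before
// ins. For every time step it computes Ht = f(Xt * W^T + H(t-1) * R^T + Wb + Rb)
// and returns {hidden states of all steps stacked along the sequence axis,
// hidden state of the last step}.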
std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
                                                   instruction_ref w,
                                                   instruction_ref r,
                                                   instruction_ref bias,
                                                   instruction_ref ih,
                                                   operation& actv_func) const
{
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
    auto sw      = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto tran_sw = prog.insert_instruction(ins, op::transpose{perm}, sw);

    // squeeze and transpose r
    auto sr      = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
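    // the rnn bias packs Wb and Rb into 2 * hidden_size values; slice out the
    // two halves, add them, and broadcast the sum across the batch dimension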
    if(bias != prog.end())
    {
        long hs    = r->get_shape().lens()[2];
        auto sbias = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb    = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb    = prog.insert_instruction(ins, op::slice{{0}, {hs}, {2 * hs}}, sbias);
        auto b     = prog.insert_instruction(ins, op::add{}, wb, rb);
        bias       = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
    }

    instruction_ref hidden_out, last_out;
    std::size_t seq_len = input->get_shape().lens()[0];
    long seq_index      = is_forward ? 0 : seq_len - 1;
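    // walk the sequence one time step at a time; seq_index moves forward or
    // backward depending on the direction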
    for(std::size_t i = 0; i < seq_len; i++)
    {
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
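        // Xt * W^T and H(t-1) * R^T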
        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
        instruction_ref ht;
        if(bias != prog.end())
        {
            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
        }
        else
        {
            ht = xt_ht;
        }

        // apply activation function
        ht  = prog.insert_instruction(ins, actv_func, ht);
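        // carry the hidden state into the next time step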
        sih = ht;

        // add the dimension of sequence length
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, ht);

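        // accumulate the per-step outputs in time order: forward appends to
        // the end, reverse prepends to the front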
        if(is_forward)
        {
            hidden_out = (seq_index == 0)
                             ? last_out
                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
        }
        else
        {
            hidden_out = (seq_index == seq_len - 1)
                             ? last_out
                             : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
        }
        seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
    }

    std::vector<instruction_ref> out_args;
    out_args.push_back(hidden_out);
    out_args.push_back(last_out);

    return out_args;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx