rewrite_rnn.cpp 10 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Replaces every `rnn` instruction in `prog` with an unrolled sequence of
/// primitive operations produced by rnn_cell().
///
/// The rnn operator takes 3 to 6 inputs: X, W, R and, optionally, B,
/// sequence_lens and initial_h. The sequence_lens input can be undefined in
/// a pytorch-exported onnx model and dropped by protobuf, so a 5-input rnn is
/// disambiguated by rank: a rank-3 5th input is treated as initial_h.
///
/// For a bidirectional rnn, W, R, B and initial_h are sliced per direction
/// along axis 0, each direction is unrolled separately, and the two outputs
/// are concatenated along a new num_directions axis (axis 1). Otherwise the
/// single-direction output gets a unit num_directions axis at position 1.
void rewrite_rnn::apply(program& prog) const
{
    // Last hidden state of the most recently rewritten rnn; kept for the
    // (currently disabled) rnn_last_output rewrite at the bottom of the loop.
    instruction_ref last_output = prog.end();

    for(auto ins : iterator_for(prog))
    {
        // rewrite rnn operator
        if(ins->name() == "rnn")
        {
            auto args = ins->inputs();

            shape seq_shape         = args[0]->get_shape();
            std::size_t hidden_size = args[1]->get_shape().lens()[1];
            std::size_t batch_size  = seq_shape.lens()[1];
            shape::type_t type      = seq_shape.type();

            // Default (all-zero) initial hidden state for ONE direction. It
            // keeps a leading unit axis so it has the same rank as a per-
            // direction slice of initial_h: rnn_cell() squeezes axis 0 of
            // whatever initial state it receives.
            migraphx::shape ih_shape{type, {1, batch_size, hidden_size}};
            std::vector<char> data(ih_shape.bytes(), 0);

            // initial_h is the 6th input, or the 5th one when sequence_lens
            // was dropped (a rank-3 5th input cannot be sequence_lens).
            const bool has_ih =
                args.size() == 6 || (args.size() == 5 && args[4]->get_shape().lens().size() == 3);

            auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
            op::rnn::rnn_direction_t dicrt = rnn_op.direction;
            if(dicrt == op::rnn::rnn_direction_t::bidirectional)
            {
                // input weight matrix, split per direction along axis 0
                auto w_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[1]);
                auto w_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[1]);

                // hidden state weight matrix, split per direction along axis 0
                auto r_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[2]);
                auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);

                // optional bias; prog.end() tells rnn_cell() there is none
                instruction_ref bias_forward = prog.end();
                instruction_ref bias_reverse = prog.end();
                if(args.size() >= 4)
                {
                    bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
                    bias_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[3]);
                }

                // initial hidden state, split per direction, or a shared zero state
                instruction_ref ih_forward;
                instruction_ref ih_reverse;
                if(has_ih)
                {
                    auto arg_ih = (args.size() == 6) ? args[5] : args[4];
                    ih_forward  = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, arg_ih);
                    ih_reverse  = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, arg_ih);
                }
                else
                {
                    ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                    ih_reverse = ih_forward;
                }

                auto ret_forward = rnn_cell(true,
                                            prog,
                                            ins,
                                            args[0],
                                            w_forward,
                                            r_forward,
                                            bias_forward,
                                            ih_forward,
                                            rnn_op.actv_funcs.at(0));
                auto ret_reverse = rnn_cell(false,
                                            prog,
                                            ins,
                                            args[0],
                                            w_reverse,
                                            r_reverse,
                                            bias_reverse,
                                            ih_reverse,
                                            rnn_op.actv_funcs.at(1));

                last_output =
                    prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);

                // add the num_directions dimension
                ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
                ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);

                // concat the forward and reverse output along num_directions
                prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
            }
            else
            {
                const bool is_forward = (dicrt == op::rnn::rnn_direction_t::forward);

                // input weight matrix
                auto w = args[1];

                // hidden state weight matrix
                auto r = args[2];

                // optional bias; prog.end() tells rnn_cell() there is none
                instruction_ref bias = prog.end();
                if(args.size() >= 4)
                {
                    bias = args[3];
                }

                // initial hidden state, or a zero state when it is not given
                instruction_ref ih;
                if(has_ih)
                {
                    ih = (args.size() == 6) ? args[5] : args[4];
                }
                else
                {
                    ih = prog.add_literal(migraphx::literal{ih_shape, data});
                }

                auto ret = rnn_cell(is_forward,
                                    prog,
                                    ins,
                                    args[0],
                                    w,
                                    r,
                                    bias,
                                    ih,
                                    rnn_op.actv_funcs.at(0));
                last_output = ret[1];

                // add the num_directions dimension
                prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
            }
        }

        // Rewrite the rnn_last_output operator that directly follows the rnn
        // operator. It could be computed by slicing the rnn output, but the
        // last output already exists from the rnn rewrite above, so it can
        // be reused here directly.
        // if (ins->name() == "rnn_last_output")
        //{
        //    // if rnn operator is executed, the last_output != prog.end()
        //    if (last_output != prog.end())
        //    {
        //        prog.replace_instruction(ins, op::identity{}, last_output);
        //        last_output = prog.end();
        //    }
        //}
    }
}

Shucai Xiao's avatar
Shucai Xiao committed
155
std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
156
157
158
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
Shucai Xiao's avatar
Shucai Xiao committed
159
160
                                                   instruction_ref w,
                                                   instruction_ref r,
Shucai Xiao's avatar
Shucai Xiao committed
161
                                                   instruction_ref bias,
Shucai Xiao's avatar
Shucai Xiao committed
162
                                                   instruction_ref ih,
Shucai Xiao's avatar
Shucai Xiao committed
163
                                                   operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
164
{
Shucai Xiao's avatar
Shucai Xiao committed
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
    // squeeze and transpose w
    std::vector<int64_t> perm{1, 0};
    auto sw = prog.insert_instruction(ins, op::squeeze{{0}}, w);
    auto tran_sw = prog.insert_instruction(sw, op::transpose{perm}, sw);

    // squeeze and transpose r
    auto sr = prog.insert_instruction(ins, op::squeeze{{0}}, r);
    auto tran_sr = prog.insert_instruction(ins, op::transpose{perm}, sr);

    // initial hidden state
    auto sih = prog.insert_instruction(ins, op::squeeze{{0}}, ih);

    // bias
    if (bias != prog.end()) 
    {
        long hs = r->get_shape().lens()[2];
        auto sbias      = prog.insert_instruction(ins, op::squeeze{{0}}, bias);
        auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {hs}}, sbias);
        auto rb =
            prog.insert_instruction(ins, op::slice{{0}, {hs}, {2*hs}}, sbias);
        auto b      = prog.insert_instruction(ins, op::add{}, wb, rb);
        bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
    }

    instruction_ref hidden_out, last_out;
    std::size_t seq_len         = input->get_shape().lens()[0];
Shucai Xiao's avatar
Shucai Xiao committed
191
192
193
    long seq_index              = is_forward ? 0 : seq_len - 1;
    for(std::size_t i = 0; i < seq_len; i++)
    {
Shucai Xiao's avatar
Shucai Xiao committed
194
195
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
Shucai Xiao's avatar
Shucai Xiao committed
196
197
198
199
        auto xt_wi = prog.insert_instruction(ins, op::dot{}, xt, tran_sw);
        auto ht_ri = prog.insert_instruction(ins, op::dot{}, sih, tran_sr);
        auto xt_ht = prog.insert_instruction(ins, op::add{}, xt_wi, ht_ri);
        instruction_ref ht;
Shucai Xiao's avatar
Shucai Xiao committed
200
201
        if(bias != prog.end())
        {
Shucai Xiao's avatar
Shucai Xiao committed
202
            ht = prog.insert_instruction(ins, op::add{}, xt_ht, bias);
Shucai Xiao's avatar
Shucai Xiao committed
203
204
205
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
206
            ht = xt_ht;
Shucai Xiao's avatar
Shucai Xiao committed
207
208
209
        }

        // apply activation function
Shucai Xiao's avatar
Shucai Xiao committed
210
211
        ht = prog.insert_instruction(ins, actv_func, ht);
        sih = ht;
Shucai Xiao's avatar
Shucai Xiao committed
212
213

        // add the dimension of sequence length
Shucai Xiao's avatar
Shucai Xiao committed
214
        last_out = prog.insert_instruction(ins, op::unsqueeze{{0}}, ht);
Shucai Xiao's avatar
Shucai Xiao committed
215
216
217
218

        if(is_forward)
        {
            hidden_out = (seq_index == 0)
Shucai Xiao's avatar
Shucai Xiao committed
219
220
                             ? last_out
                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, last_out);
Shucai Xiao's avatar
Shucai Xiao committed
221
222
223
224
        }
        else
        {
            hidden_out = (seq_index == seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
225
226
                             ? last_out
                             : prog.insert_instruction(ins, op::concat{0}, last_out, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
227
228
229
230
231
232
        }
        seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
    }

    std::vector<instruction_ref> out_args;
    out_args.push_back(hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
233
    out_args.push_back(last_out);
Shucai Xiao's avatar
Shucai Xiao committed
234
235
236
237
238
239

    return out_args;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx