rewrite_rnn.cpp 10.6 KB
Newer Older
Shucai Xiao's avatar
Shucai Xiao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void rewrite_rnn::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        if(ins->name() != "rnn")
        {
            continue;
        }

        // could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
        // the 5th one is undefined and ignored by protobuf. so
        // we need to process up to 5 inputs
        auto args = ins->inputs();

        shape seq_shape         = args[0]->get_shape();
        shape wgt_shape         = args[1]->get_shape();
        std::size_t hidden_size = wgt_shape.lens()[1];
        std::size_t batch_size  = seq_shape.lens()[1];
        shape::type_t type      = seq_shape.type();
        // shape of one hidden state: {batch_size, hidden_size}
        migraphx::shape s{type, {batch_size, hidden_size}};
        const long h_size = static_cast<long>(hidden_size);
        const std::vector<int64_t> perm{1, 0};

        // Drop the leading (num_directions) axis of a single-direction input.
        auto squeeze_direction = [&](instruction_ref t) {
            return prog.insert_instruction(ins, op::squeeze{{0}}, t);
        };
        // Slice direction `dir` out of axis 0 and drop that axis.
        auto select_direction = [&](instruction_ref t, long dir) {
            auto sliced = prog.insert_instruction(ins, op::slice{{0}, {dir}, {dir + 1}}, t);
            return squeeze_direction(sliced);
        };
        // Transpose a 2-D weight matrix (perm {1, 0}) so rnn_oper can use it as
        // the right-hand operand of dot.
        auto transpose_weight = [&](instruction_ref w) {
            return prog.insert_instruction(ins, op::transpose{perm}, w);
        };
        // A 1-D bias of length 2 * hidden_size holds two halves (wb / rb);
        // add them and broadcast the sum to s along axis 1.
        auto combine_bias = [&](instruction_ref packed) {
            auto wb = prog.insert_instruction(ins, op::slice{{0}, {0}, {h_size}}, packed);
            auto rb =
                prog.insert_instruction(ins, op::slice{{0}, {h_size}, {2 * h_size}}, packed);
            auto b = prog.insert_instruction(ins, op::add{}, wb, rb);
            return prog.insert_instruction(ins, op::broadcast{1, s}, b);
        };
        // Zero-filled initial hidden state of shape s. The literal's
        // (shape, std::vector<T>) constructor stores one element of s per vector
        // entry, so the vector must hold s.elements() values; a std::vector<char>
        // of s.bytes() entries over-runs the buffer for multi-byte types.
        auto zero_hidden = [&]() {
            std::vector<float> zeros(s.elements(), 0.0f);
            return prog.add_literal(migraphx::literal{s, zeros});
        };

        auto rnn_op                    = any_cast<op::rnn>(ins->get_operator());
        op::rnn::rnn_direction_t dicrt = rnn_op.direction;
        if(dicrt == op::rnn::rnn_direction_t::bidirectional)
        {
            // input weight matrix (args[1]) and hidden state weight matrix
            // (args[2]); direction 0 is forward, direction 1 is reverse
            auto trans_xw_forward = transpose_weight(select_direction(args[1], 0));
            auto trans_xw_reverse = transpose_weight(select_direction(args[1], 1));
            auto trans_hw_forward = transpose_weight(select_direction(args[2], 0));
            auto trans_hw_reverse = transpose_weight(select_direction(args[2], 1));

            // bias (optional); prog.end() tells rnn_oper there is no bias
            instruction_ref bias_forward = prog.end();
            instruction_ref bias_reverse = prog.end();
            if(args.size() >= 4)
            {
                bias_forward = combine_bias(select_direction(args[3], 0));
                bias_reverse = combine_bias(select_direction(args[3], 1));
            }

            // initial hidden state (optional, defaults to zeros)
            instruction_ref ih_forward;
            instruction_ref ih_reverse;
            if(args.size() >= 5)
            {
                ih_forward = select_direction(args[4], 0);
                ih_reverse = select_direction(args[4], 1);
            }
            else
            {
                // both directions can share a single zero literal
                ih_forward = ih_reverse = zero_hidden();
            }

            auto ret_forward = rnn_oper(true,
                                        prog,
                                        ins,
                                        args[0],
                                        trans_xw_forward,
                                        trans_hw_forward,
                                        ih_forward,
                                        bias_forward,
                                        rnn_op.actv_funcs.at(0));
            auto ret_reverse = rnn_oper(false,
                                        prog,
                                        ins,
                                        args[0],
                                        trans_xw_reverse,
                                        trans_hw_reverse,
                                        ih_reverse,
                                        bias_reverse,
                                        rnn_op.actv_funcs.at(1));

            // add the dimension of num_direction
            ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
            ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);

            // concat the forward and reverse output
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
        else
        {
            bool is_forward = (dicrt == op::rnn::forward);

            // input weight matrix and hidden state weight matrix
            auto trans_xw = transpose_weight(squeeze_direction(args[1]));
            auto trans_hw = transpose_weight(squeeze_direction(args[2]));

            // bias (optional); prog.end() tells rnn_oper there is no bias
            instruction_ref bias = prog.end();
            if(args.size() >= 4)
            {
                bias = combine_bias(squeeze_direction(args[3]));
            }

            // initial hidden state (optional, defaults to zeros)
            instruction_ref ih = (args.size() >= 5) ? squeeze_direction(args[4]) : zero_hidden();

            auto ret = rnn_oper(is_forward,
                                prog,
                                ins,
                                args[0],
                                trans_xw,
                                trans_hw,
                                ih,
                                bias,
                                rnn_op.actv_funcs.at(0));

            // add the dimension of num_direction
            prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
        }
    }
}

std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
Shucai Xiao's avatar
Shucai Xiao committed
180
181
182
183
184
185
186
187
                                                   program& prog,
                                                   instruction_ref ins,
                                                   instruction_ref input,
                                                   instruction_ref wx,
                                                   instruction_ref wh,
                                                   instruction_ref ih,
                                                   instruction_ref bias,
                                                   operation& actv_func) const
Shucai Xiao's avatar
Shucai Xiao committed
188
189
190
191
192
193
194
{
    instruction_ref hidden_out, final_out;
    migraphx::shape input_shape = input->get_shape();
    std::size_t seq_len         = input_shape.lens()[0];
    long seq_index              = is_forward ? 0 : seq_len - 1;
    for(std::size_t i = 0; i < seq_len; i++)
    {
Shucai Xiao's avatar
Shucai Xiao committed
195
196
        auto xt = prog.insert_instruction(ins, op::slice{{0}, {seq_index}, {seq_index + 1}}, input);
        xt      = prog.insert_instruction(ins, op::squeeze{{0}}, xt);
Shucai Xiao's avatar
Shucai Xiao committed
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
        auto x_w = prog.insert_instruction(ins, op::dot{}, xt, wx);
        auto h_r = prog.insert_instruction(ins, op::dot{}, ih, wh);
        auto x_h = prog.insert_instruction(ins, op::add{}, x_w, h_r);
        instruction_ref before_actv;
        if(bias != prog.end())
        {
            before_actv = prog.insert_instruction(ins, op::add{}, x_h, bias);
        }
        else
        {
            before_actv = x_h;
        }

        // apply activation function
        ih = prog.insert_instruction(ins, actv_func, before_actv);

        // add the dimension of sequence length
        auto output = prog.insert_instruction(ins, op::unsqueeze{{0}}, ih);
        final_out   = output;

        if(is_forward)
        {
            hidden_out = (seq_index == 0)
Shucai Xiao's avatar
Shucai Xiao committed
220
221
                             ? output
                             : prog.insert_instruction(ins, op::concat{0}, hidden_out, output);
Shucai Xiao's avatar
Shucai Xiao committed
222
223
224
225
        }
        else
        {
            hidden_out = (seq_index == seq_len - 1)
Shucai Xiao's avatar
Shucai Xiao committed
226
227
                             ? output
                             : prog.insert_instruction(ins, op::concat{0}, output, hidden_out);
Shucai Xiao's avatar
Shucai Xiao committed
228
229
230
231
232
233
234
235
236
237
238
239
240
        }
        seq_index = is_forward ? (seq_index + 1) : (seq_index - 1);
    }

    std::vector<instruction_ref> out_args;
    out_args.push_back(hidden_out);
    out_args.push_back(final_out);

    return out_args;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx