lowering.cpp 32.6 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
4
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/dfor.hpp>
5
#include <migraphx/op/identity.hpp>
6
#include <migraphx/op/batch_norm_inference.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/op/convolution.hpp>
kahmed10's avatar
kahmed10 committed
8
#include <migraphx/op/deconvolution.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
9
#include <migraphx/op/quant_convolution.hpp>
Paul's avatar
Paul committed
10
#include <migraphx/op/dot.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
11
#include <migraphx/op/quant_dot.hpp>
Paul's avatar
Paul committed
12
13
14
15
16
17
18
19
#include <migraphx/op/elu.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/softmax.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
20
21
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
22
#include <migraphx/op/rnn_var_sl_last_output.hpp>
Paul's avatar
Paul committed
23
24
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
25
#include <migraphx/par_dfor.hpp>
26
#include <migraphx/clamp.hpp>
Paul's avatar
Paul committed
27
#include <migraphx/cpu/gemm.hpp>
28
#include <migraphx/register_op.hpp>
29
#include <migraphx/make_op.hpp>
30
#include <migraphx/program.hpp>
Paul's avatar
Paul committed
31
#include <unordered_map>
Paul's avatar
Paul committed
32
#include <utility>
kahmed10's avatar
kahmed10 committed
33
#include <iostream>
Paul's avatar
Paul committed
34

Paul's avatar
Paul committed
35
namespace migraphx {
Paul's avatar
Paul committed
36
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
37
38
39
40
41
42
43
44
namespace cpu {

// Return the zero value of the argument's type; the argument itself is
// used only for type deduction and is never read.
template <typename T>
T zero(const T&)
{
    return static_cast<T>(0);
}

Khalique's avatar
Khalique committed
45
46
47
48
// Convert an integral value to its signed counterpart; non-integral values
// (e.g. floating point) pass through unchanged. The two SFINAE overloads
// keep std::make_signed from being instantiated for non-integral types,
// which would be ill-formed.
template <class T, std::enable_if_t<std::is_integral<T>{}, int> = 0>
typename std::make_signed<T>::type make_signed(T x)
{
    return x;
}

template <class T, std::enable_if_t<not std::is_integral<T>{}, int> = 0>
T make_signed(T x)
{
    return x;
}

53
54
55
56
//
// cpu implementation of batch norm for inference
//
// inputs are:
// args[0] -> input data buffer
// args[1] -> gamma
// args[2] -> bias
// args[3] -> mini batch mean
// args[4] -> mini batch variance
//
// The equation to compute batch norm for inference is:
//
// output[i] = bias + gamma * (input[i] - mean) / sqrt(variance + epsilon)
//
// the input data format should be nchw
//
// Reference (CPU) batch normalization for inference.
// Argument order: args[0]=input, args[1]=gamma, args[2]=bias,
// args[3]=mini-batch mean, args[4]=mini-batch variance (NCHW layout).
struct cpu_batch_norm_inference
{
    op::batch_norm_inference op;

    // Forward reflection to the wrapped operator so serialization works.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::batch_norm_inference"; }

    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument output{output_shape};

        double epsilon           = op.epsilon;
        auto input               = args[0];
        auto arg_gamma           = args[1];
        auto arg_bias            = args[2];
        auto mini_batch_mean     = args[3];
        auto mini_batch_variance = args[4];

        // Spatial mode: one mean/variance/gamma/bias value per channel
        // (index 1 of the NCHW multi-index).
        if(op.bn_mode == op::batch_norm_inference::spatial)
        {
            visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
                [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
                    par_for(output_shape.elements(), [&](auto i) {
                        auto idx = output_shape.multi(i);
                        auto c   = idx[1];
                        assert((variance[c] + epsilon) > 0);
                        // y = gamma * (x - mean) / sqrt(var + eps) + bias
                        result[i] =
                            gamma[c] * (buffer[i] - mean[c]) / std::sqrt(variance[c] + epsilon) +
                            bias[c];
                    });
                });
        }

        // Per-activation mode: statistics are indexed per (c, h, w)
        // position and shared across the batch dimension, hence idx[0]
        // is zeroed before computing the linear index.
        if(op.bn_mode == op::batch_norm_inference::per_activation)
        {
            visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
                [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
                    par_for(output_shape.elements(), [&](auto i) {
                        auto idx   = output_shape.multi(i);
                        idx[0]     = 0;
                        auto index = output_shape.index(idx);

                        assert((variance[index] + epsilon) > 0);
                        result[i] = gamma[index] * (buffer[i] - mean[index]) /
                                        std::sqrt(variance[index] + epsilon) +
                                    bias[index];
                    });
                });
        }

        return output;
    }
};
MIGRAPHX_REGISTER_OP(cpu_batch_norm_inference)
130

Khalique's avatar
Khalique committed
131
struct cpu_lrn
Khalique's avatar
Khalique committed
132
{
Khalique's avatar
Khalique committed
133
    op::lrn op;
Khalique's avatar
Khalique committed
134

135
136
137
138
139
140
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

Khalique's avatar
Khalique committed
141
    std::string name() const { return "cpu::lrn"; }
Khalique's avatar
Khalique committed
142
143
144
145
146
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0])([&](auto output, auto input) {
Khalique's avatar
Khalique committed
147
148
149
150
            int n_batch         = output_shape.lens()[0];
            int channels        = output_shape.lens()[1];
            int height          = output_shape.lens()[2];
            int width           = output_shape.lens()[3];
Paul's avatar
Paul committed
151
            float alphaoverarea = op.alpha / float(op.size);
152
153
            int radius_lower    = (op.size - 1) / 2;
            int radius_upper    = op.size / 2 + 1;
Khalique's avatar
Khalique committed
154

155
            par_dfor(n_batch, height, width)([&](int b, int h, int w) {
Khalique's avatar
Khalique committed
156
                float scale = 0;
Khalique's avatar
Khalique committed
157
                dfor(channels)([&](int c) {
158
159
                    auto start = (c - radius_lower) < 0 ? 0 : (c - radius_lower);
                    auto end   = (c + radius_upper) > channels ? channels : (c + radius_upper);
Khalique's avatar
Khalique committed
160
161
                    for(auto k = start; k < end; ++k)
                    {
Khalique's avatar
Khalique committed
162
                        scale += std::pow(input(b, k, h, w), 2);
Khalique's avatar
Khalique committed
163
164
165
                    }
                    scale *= alphaoverarea;
                    scale += op.bias;
Khalique's avatar
Khalique committed
166
                    scale              = std::pow(scale, -op.beta);
Khalique's avatar
Khalique committed
167
168
169
170
171
172
173
                    output(b, c, h, w) = input(b, c, h, w) * scale;
                });
            });
        });
        return result;
    }
};
174
MIGRAPHX_REGISTER_OP(cpu_lrn)
Khalique's avatar
Khalique committed
175

Paul Fultz II's avatar
Paul Fultz II committed
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
// Visit the first tensor separately from the rest so it may hold a
// different element type (e.g. an int32 output over int8 inputs for
// quantized ops); the remaining tensors are visited together.
template <class V, class T, class... Ts>
void visit_quantize_impl(V&& visitor, T&& first, Ts&&... rest)
{
    first.visit([&](auto first_view) {
        visit_all(rest...)([&](auto... rest_views) { visitor(first_view, rest_views...); });
    });
}

// Returns a callable taking the visitor, mirroring the visit_all interface.
template <class T, class... Ts>
auto visit_quantize(T&& first, Ts&&... rest)
{
    return [&](auto visitor) {
        // Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
        visit_quantize_impl(visitor, first, rest...);
    };
}

191
// Reference (CPU) n-dimensional convolution, templated over the operator
// type (op::convolution or op::quant_convolution). Registered via the
// auto_register_op base class.
template <class Op>
struct cpu_convolution : auto_register_op<cpu_convolution<Op>>
{
    cpu_convolution() = default;

    cpu_convolution(Op pop) : op(std::move(pop)) {}

    Op op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::" + op.name(); }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        // visit_quantize lets the output element type differ from the
        // input/weight types (quantized convolution accumulates wider).
        visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) {
            auto in_lens = input.get_shape().lens();

            auto wei_lens = weights.get_shape().lens();
            auto wei_n    = wei_lens[0]; // number of output feature maps
            auto wei_c    = wei_lens[1]; // input channels per group
            // Window iterated per output element: {wei_c, kernel dims...}
            std::vector<std::size_t> win_size(wei_lens.begin() + 1, wei_lens.end());

            par_for(output_shape.elements(), [&](auto i) {
                auto idx_o = output_shape.multi(i);
                auto w     = idx_o[1]; // output feature map index
                auto n_dim = idx_o.size();

                // Top-left corner of the receptive field in the input, one
                // entry per spatial dim (can be negative due to padding).
                std::vector<std::ptrdiff_t> win_start;
                for(std::size_t dim = 2; dim < n_dim; ++dim)
                {
                    auto d_2 = dim - 2;
                    win_start.push_back(std::ptrdiff_t(idx_o[dim] * op.stride[d_2]) -
                                        std::ptrdiff_t(op.padding[d_2]));
                }
                const auto group_id = w / (wei_n / op.group);

                shape win_shape{output_shape.type(), win_size};

                // Accumulate in double regardless of element type.
                double acc = 0.0;
                shape_for_each(win_shape, [&](auto idx_win) {
                    auto k           = idx_win[0];
                    const auto in_ch = group_id * wei_c + k;
                    std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
                    idx[1] = in_ch;
                    // Input position = window corner + kernel tap offset.
                    std::transform(idx_win.begin() + 1,
                                   idx_win.end(),
                                   win_start.begin(),
                                   idx.begin() + 2,
                                   [](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; });
                    std::vector<std::ptrdiff_t> idx_wei(idx_o.size());
                    idx_wei[0] = w;
                    std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1);
                    // Skip taps that fall into the padding region: every
                    // spatial coordinate must be >= 0 and element-wise less
                    // than the input lengths.
                    if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
                       std::equal(idx.begin(),
                                  idx.end(),
                                  in_lens.begin(),
                                  in_lens.end(),
                                  std::less<std::ptrdiff_t>{}))
                    {
                        acc +=
                            input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end());
                    }
                });

                output[i] = acc;
            });
        });
        return result;
    }
};

kahmed10's avatar
kahmed10 committed
268
// Reference (CPU) n-dimensional transposed convolution (deconvolution),
// implemented scatter-style: each input element distributes its
// contribution over the output through every kernel tap.
template <class Op>
struct cpu_deconvolution : auto_register_op<cpu_deconvolution<Op>>
{
    cpu_deconvolution() = default;

    cpu_deconvolution(Op pop) : op(std::move(pop)) {}

    Op op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::" + op.name(); }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
            using type = typename decltype(output)::value_type;

            // Start from zero; contributions are accumulated below.
            std::fill(output.begin(), output.end(), type{0});

            auto in_lens = input.get_shape().lens();
            auto in_n    = in_lens[0];
            auto in_c    = in_lens[1];

            auto wei   = weights.get_shape().lens();
            auto wei_n = wei[0];
            auto wei_c = wei[1];

            auto out_lens = output_shape.lens();
            auto kdims    = op.kdims();

            // Iteration space couples input channel, input spatial dims and
            // kernel spatial dims: {in_c, in_spatial..., kernel_spatial...}.
            std::vector<std::size_t> win_size{in_c};
            std::copy(in_lens.begin() + 2, in_lens.end(), std::back_inserter(win_size));
            std::copy(wei.begin() + 2, wei.end(), std::back_inserter(win_size));
            shape win_shape{output_shape.type(), win_size};

            // Parallelize over (batch, output channel within group).
            par_dfor(in_n, wei_c)([&](int o, int k) {

                shape_for_each(win_shape, [&](auto idx_win) {
                    const int w = idx_win[0]; // input channel / weight row

                    auto input_dims_start = idx_win.begin() + 1;
                    auto wei_dims_start   = idx_win.begin() + kdims + 1;

                    // Output corner for this input pixel (stride applied,
                    // padding subtracted; may be negative).
                    std::vector<std::ptrdiff_t> win_start;
                    for(std::size_t n = 0; n < kdims; ++n)
                    {
                        win_start.push_back(std::ptrdiff_t(*(input_dims_start + n) * op.stride[n]) -
                                            std::ptrdiff_t(op.padding[n]));
                    }

                    const int group_id = w / (wei_n / op.group);
                    const int in_ch    = group_id * wei_c + k;

                    std::vector<std::ptrdiff_t> idx_out{o, in_ch};

                    // Output position = corner + dilated kernel tap offset.
                    for(size_t n = 0; n < kdims; n++)
                    {
                        idx_out.push_back(win_start[n] + *(wei_dims_start + n) * op.dilation[n]);
                    }

                    std::vector<std::ptrdiff_t> idx_wei{w, k};
                    std::copy(wei_dims_start, idx_win.end(), std::back_inserter(idx_wei));

                    std::vector<std::ptrdiff_t> idx_in{o, w};
                    std::copy(input_dims_start, wei_dims_start, std::back_inserter(idx_in));

                    // Accumulate only when the target lies inside the output;
                    // positions outside correspond to trimmed padding.
                    if(std::all_of(
                           idx_out.begin() + 2, idx_out.end(), [&](auto ii) { return ii >= 0; }) and
                       std::equal(idx_out.begin() + 2,
                                  idx_out.end(),
                                  out_lens.begin() + 2,
                                  out_lens.end(),
                                  std::less<std::ptrdiff_t>{}))
                    {
                        output(idx_out.begin(), idx_out.end()) +=
                            input(idx_in.begin(), idx_in.end()) *
                            weights(idx_wei.begin(), idx_wei.end());
                    }
                });

            });
        });
        return result;
    }
};

Scott Thornton's avatar
Scott Thornton committed
361
362
// Reference (CPU) im2col: unrolls the convolution windows of a single
// image (batch size 1, NCHW) into rows of a column matrix so convolution
// can be performed as a GEMM.
struct cpu_im2col
{
    op::im2col op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    static std::string name() { return "cpu::im2col"; }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        auto input_shape   = args[0].get_shape();
        auto weights_shape = args[1].get_shape();
        visit_all(result, args[0])([&](auto col, auto input) {
            const std::size_t& height   = input_shape.lens()[2];
            const std::size_t& width    = input_shape.lens()[3];
            const std::size_t& channels = weights_shape.lens()[1];
            const std::size_t& kernel_h = weights_shape.lens()[2];
            const std::size_t& kernel_w = weights_shape.lens()[3];
            const std::size_t& pad_h    = op.padding[0];
            const std::size_t& pad_w    = op.padding[1];
            const std::size_t& stride_h = op.stride[0];
            const std::size_t& stride_w = op.stride[1];

            // Half-kernel offsets (signed: positions may go negative into
            // the padding region).
            long kdiv2_h = long(kernel_h) / 2;
            long kdiv2_w = long(kernel_w) / 2;
            // calculate output sizes
            const std::size_t col_height = (height - kernel_h + 2 * pad_h) / stride_h + 1;
            const std::size_t col_width  = (width - kernel_w + 2 * pad_w) / stride_w + 1;
            // account for padding for the starting position of the input pixels
            long iinput = kdiv2_h - long(pad_h);
            // loop over output pixels (ioutput, joutput)
            for(std::size_t ioutput = 0; ioutput < col_height; ioutput++, iinput += stride_h)
            {
                long jinput = kdiv2_w - long(pad_w);
                for(std::size_t joutput = 0; joutput < col_width; joutput++, jinput += stride_w)
                {
                    // compute linear index for output
                    std::size_t ldx = ioutput * col_width + joutput;
                    std::size_t p   = 0;
                    // Gather the (channels x kernel_h x kernel_w) window into
                    // row ldx; out-of-bounds (padding) pixels contribute 0.
                    dfor(channels,
                         kernel_h,
                         kernel_w)([&](std::size_t c, std::size_t koffset, std::size_t loffset) {
                        auto idx    = iinput + long(koffset) - kdiv2_h;
                        auto jdx    = jinput + long(loffset) - kdiv2_w;
                        col(ldx, p) = ((idx >= 0) && (idx < height) && (jdx >= 0) && (jdx < width))
                                          ? input(0, c, idx, jdx)
                                          : 0;
                        p++;
                    });
                }
            }
        });
        return result;
    }
};
MIGRAPHX_REGISTER_OP(cpu_im2col)
Scott Thornton's avatar
Scott Thornton committed
423

Paul's avatar
Paul committed
424
425
426
// Pooling policy that keeps the maximum value seen in the window.
struct max_pool
{
    static std::string name() { return "max"; }

    // Identity element for a max-reduction: the smallest representable T.
    template <class T>
    static T start()
    {
        return std::numeric_limits<T>::lowest();
    }

    // Reduction step: keep the larger of the accumulator and the element.
    static double apply(double x, double y) { return std::max(x, y); }

    // Max pooling needs no normalization by window size.
    static double final(double x, std::size_t) { return x; }
};

// Pooling policy that averages the values in the window.
struct avg_pool
{
    static std::string name() { return "average"; }

    // Running sum starts at zero; note it returns double regardless of T.
    template <class T>
    static double start()
    {
        return 0.0;
    }

    // Reduction step: accumulate the sum of the window's elements.
    static double apply(double x, double y) { return x + y; }

    // Divide the accumulated sum by the pooled element count, guarding
    // against an empty window.
    static double final(double x, std::size_t y)
    {
        if(y == 0)
            return 0.0;
        return x / y;
    }
};

// Reference (CPU) n-dimensional pooling; Op is a policy type (max_pool or
// avg_pool) supplying start/apply/final. Registered via auto_register_op.
template <class Op>
struct cpu_pooling : auto_register_op<cpu_pooling<Op>>
{
    cpu_pooling() = default;

    cpu_pooling(op::pooling pop) : op(std::move(pop)) {}

    op::pooling op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::pooling_" + Op::name(); }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0])([&](auto output, auto input) {
            using type   = typename decltype(output)::value_type;
            auto in_s    = input.get_shape();
            auto in_lens = in_s.lens();
            std::vector<std::size_t> vec_len(in_lens.begin() + 2, in_lens.end());

            par_for(output_shape.elements(), [&](auto i) {
                auto idx_o = output_shape.multi(i);
                auto n_dim = idx_o.size();
                // Clamp the pooling window to the input bounds, so padding
                // positions are excluded from the window (and, for average
                // pooling, from pool_size as well).
                std::vector<std::size_t> win_start;
                std::vector<std::size_t> win_size;
                for(std::size_t dim = 2; dim < n_dim; ++dim)
                {
                    auto d_2  = dim - 2;
                    int start = static_cast<int>(idx_o[dim] * op.stride[d_2]) -
                                static_cast<int>(op.padding[d_2]);
                    int end = std::min(start + op.lengths[d_2], in_lens[dim]);
                    start   = std::max(start, 0);
                    win_start.push_back(start);
                    win_size.push_back(end - start);
                }

                shape win_shape{output_shape.type(), win_size};
                auto pool_size = win_shape.elements();
                double acc     = Op::template start<type>();
                shape_for_each(win_shape, [&](auto idx_w) {
                    auto idx = idx_o;
                    // Input position = clamped window start + window offset.
                    std::transform(idx_w.begin(),
                                   idx_w.end(),
                                   win_start.begin(),
                                   idx.begin() + 2,
                                   [](auto ii, auto jj) { return ii + jj; });
                    // NOTE(review): `idx < in_lens` is a lexicographic vector
                    // comparison, not an element-wise bound check; it appears
                    // harmless only because the window was already clamped to
                    // the input bounds above — confirm before relying on it.
                    if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
                       idx < in_lens)
                    {
                        acc = Op::apply(acc, input[in_s.index(idx)]);
                    }
                });

                output[i] = type(Op::final(acc, pool_size));
            });
        });

        return result;
    }
};

524
// Generic wrapper that lowers an arbitrary operation to the cpu target by
// delegating shape computation and evaluation to the wrapped operation
// (defaults to identity).
struct cpu_op
{
    operation op = op::identity{};
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }
    std::string name() const { return "cpu::op"; }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, const shape& output_shape, const std::vector<argument>& args) const
    {
        return op.compute(output_shape, args);
    }
    // Serialize as {"name", "operator"} so the wrapped op can be
    // reconstructed by from_value below.
    value to_value() const
    {
        value v;
        v["name"]     = op.name();
        v["operator"] = op.to_value();
        return v;
    }
    void from_value(const value& v)
    {
        op = make_op(v.at("name").to<std::string>(), v.at("operator"));
    }
    friend std::ostream& operator<<(std::ostream& os, const cpu_op& x)
    {
        os << "cpu::" << x.op;
        return os;
    }
};
MIGRAPHX_REGISTER_OP(cpu_op)
Paul's avatar
Paul committed
556

Khalique's avatar
Khalique committed
557
struct cpu_pad
558
{
Khalique's avatar
Khalique committed
559
    op::pad op;
560
561
562
563
564
565
566

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

kahmed10's avatar
kahmed10 committed
567
    std::string name() const { return "cpu::pad"; }
568
569
570
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
Khalique's avatar
Khalique committed
571
        assert(output_shape.standard());
572
        argument result{output_shape};
573
574
575
576
        result.visit([&](auto output) {
            using type = typename decltype(output)::value_type;
            std::fill(output.begin(), output.end(), pad_clamp<type>(op.value));
        });
Khalique's avatar
Khalique committed
577
578

        visit_all(result, args[0])([&](auto output, auto input) {
579
            shape_for_each(input.get_shape(), [&](const auto& idx) {
Khalique's avatar
Khalique committed
580
                std::vector<std::size_t> new_idx(idx.size());
Khalique's avatar
Khalique committed
581
582
583
584
                std::transform(
                    idx.begin(), idx.end(), op.pads.begin(), new_idx.begin(), [](auto i, auto j) {
                        return i + j;
                    });
Khalique's avatar
Khalique committed
585
                output(new_idx.begin(), new_idx.end()) = input(idx.begin(), idx.end());
586
            });
Khalique's avatar
Khalique committed
587
588
        });

589
590
591
        return result;
    }
};
592
MIGRAPHX_REGISTER_OP(cpu_pad)
593

Paul's avatar
Paul committed
594
595
struct cpu_gemm
{
Shucai Xiao's avatar
Shucai Xiao committed
596
    op::dot op;
597
598
599
600
601
602

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }
Shucai Xiao's avatar
Shucai Xiao committed
603
    std::string name() const { return "cpu::dot"; }
Shucai Xiao's avatar
Shucai Xiao committed
604
605
    shape compute_shape(const std::vector<shape>& inputs) const
    {
Shucai Xiao's avatar
Shucai Xiao committed
606
607
608
        if(inputs.size() == 3)
        {
            auto c_shape = inputs.at(2);
609
            check_shapes{{c_shape}, *this}.not_broadcasted();
Shucai Xiao's avatar
Shucai Xiao committed
610
        }
Shucai Xiao's avatar
Shucai Xiao committed
611
        return op.compute_shape(inputs);
Shucai Xiao's avatar
Shucai Xiao committed
612
    }
Paul's avatar
Paul committed
613

Paul's avatar
Paul committed
614
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
615
616
    {
        argument result{output_shape};
Shucai Xiao's avatar
Shucai Xiao committed
617
        // 3 inputs, it is alpha * A * B + beta * C, then
618
        // A and B are matrices, and C is of the same shape as A * B
Shucai Xiao's avatar
Shucai Xiao committed
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
        if(args.size() == 3)
        {
            // no need to consider the value of args[2]
            if(op.beta == 0.0f)
            {
                result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
            }
            else
            {
                visit_all(result, args[2])([&](auto output, auto input) {
                    std::copy(input.begin(), input.end(), output.begin());
                });
            }

            migemm(result, args[0], args[1], op.alpha, op.beta);

            return result;
        }

        // 2 input arguments
        migemm(result, args[0], args[1], op.alpha, 0.0f);

Paul's avatar
Paul committed
641
642
643
        return result;
    }
};
644
MIGRAPHX_REGISTER_OP(cpu_gemm)
Paul's avatar
Paul committed
645

646
647
648
struct cpu_quant_gemm
{
    op::quant_dot op;
649
650
651
652
653
654
655

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

656
657
658
659
660
661
    std::string name() const { return "cpu::quant_dot"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        if(inputs.size() == 3)
        {
            auto c_shape = inputs.at(2);
662
            check_shapes{{c_shape}, *this}.not_broadcasted();
663
664
665
666
667
668
669
670
671
672
673
674
675
676
        }
        return op.compute_shape(inputs);
    }

    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        // 3 inputs, it is alpha * A * B + beta * C, then
        // A and B are matrices, and C is of the same shape to A * B

        // first, convert the args[0] and args[1] from int8_t to int32_t
        argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
        argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
        arg_0.visit([&](auto output) {
Shucai Xiao's avatar
Shucai Xiao committed
677
678
            args.at(0).visit(
                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
679
680
681
        });

        arg_1.visit([&](auto output) {
Shucai Xiao's avatar
Shucai Xiao committed
682
683
            args.at(1).visit(
                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
        });

        if(args.size() == 3)
        {
            // no need to consider the value of args[2]
            if(op.beta == 0)
            {
                result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
            }
            else
            {
                visit_all(result, args[2])([&](auto output, auto input) {
                    std::copy(input.begin(), input.end(), output.begin());
                });
            }

            migemm(result, arg_0, arg_1, op.alpha, op.beta);

            return result;
        }

        // 2 input arguments
706
        migemm(result, arg_0, arg_1, op.alpha, int32_t{0});
707
708
709
710

        return result;
    }
};
711
MIGRAPHX_REGISTER_OP(cpu_gemm)
712

Khalique's avatar
Khalique committed
713
714
715
716
717
718
struct leaky_relu_op
{
    op::leaky_relu op;
    std::string name() const { return "cpu::leaky_relu"; }
    auto fcn() const
    {
Paul's avatar
Paul committed
719
        auto a = op.alpha;
Khalique's avatar
Khalique committed
720
721
722
723
        return [a](auto x) { return x > 0 ? x : x * a; };
    }
};

Khalique's avatar
Khalique committed
724
725
726
727
728
729
struct elu_op
{
    op::elu op;
    std::string name() const { return "cpu::elu"; }
    auto fcn() const
    {
Paul's avatar
Paul committed
730
        auto a = op.alpha;
Khalique's avatar
Khalique committed
731
732
733
734
        return [a](auto x) { return x > 0 ? x : a * std::expm1(x); };
    }
};

Paul's avatar
Paul committed
735
template <typename Op>
736
struct cpu_unary : auto_register_op<cpu_unary<Op>>
Paul's avatar
Paul committed
737
{
738
739
740
741
742
743
744
    cpu_unary() = default;

    template <class T>
    cpu_unary(T pop) : op(Op{std::move(pop)})
    {
    }

Paul's avatar
Paul committed
745
    Op op;
746
747
748
749
750
751

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op.op, f);
    }
Paul's avatar
Paul committed
752
    std::string name() const { return op.name(); }
Shucai Xiao's avatar
Shucai Xiao committed
753
    shape compute_shape(const std::vector<shape>& inputs) const
754
    {
755
        check_shapes{inputs, *this}.has(1);
Shucai Xiao's avatar
Shucai Xiao committed
756
        auto s = inputs.at(0);
757
        return {s.type(), s.lens()};
758
759
    }

Paul's avatar
Paul committed
760
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
761
762
    {
        argument result{output_shape};
763
764
765
        visit_all(result, args[0])([&](auto output, auto input) {
            assert(input.get_shape().standard());
            std::transform(input.begin(), input.end(), output.begin(), op.fcn());
Paul's avatar
Paul committed
766
        });
767

Paul's avatar
Paul committed
768
769
770
771
        return result;
    }
};

772
// Shared CPU implementation for softmax-family ops (instantiated with
// op::softmax and op::logsoftmax below). The reduction runs along op.axis;
// the final combining step is delegated to the wrapped op via op.output().
template <class Op>
struct cpu_softmax : auto_register_op<cpu_softmax<Op>>
{
    cpu_softmax() = default;

    cpu_softmax(Op pop) : op(std::move(pop)) {}

    Op op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::" + op.name(); }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        // Collapse the reduction axis to length 1; each element of batch_shape
        // then identifies one independent "row" along which softmax is computed.
        auto batch_lens    = output_shape.lens();
        // Negative axis counts from the back (Python-style indexing).
        int64_t tuned_axis = (op.axis < 0) ? op.axis + args[0].get_shape().lens().size() : op.axis;
        std::size_t n_dims = batch_lens[tuned_axis];
        batch_lens[tuned_axis] = 1;
        // The element type is irrelevant here; this shape is used only for
        // multi-index arithmetic over the collapsed iteration space.
        shape batch_shape{shape::int32_type, batch_lens};

        visit_all(result, args[0])([&](auto output, auto input) {
            using value_type = typename decltype(input)::value_type;
            // Per-row maximum (for numerical stability) and per-row sum of exps.
            std::vector<value_type> batch_max(batch_shape.elements(),
                                              std::numeric_limits<value_type>::lowest());
            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
            // Rows are independent, so they are processed in parallel; idx is
            // reused across the four passes with only the axis slot mutated.
            par_for(batch_shape.elements(), [&](auto i) {
                auto idx = batch_shape.multi(i);
                // Pass 1: find the max along the axis; it is subtracted before
                // exponentiating so std::exp cannot overflow.
                for(std::size_t j = 0; j < n_dims; ++j)
                {
                    idx[tuned_axis] = j;
                    batch_max[i]    = std::max(batch_max[i], input(idx.begin(), idx.end()));
                }

                // Pass 2: write exp(x - max) into the output buffer. The flat
                // index from output_shape is valid for input too (same shape).
                for(std::size_t j = 0; j < n_dims; ++j)
                {
                    idx[tuned_axis]   = j;
                    std::size_t index = output_shape.index(idx);
                    output[index]     = std::exp(input[index] - batch_max[i]);
                }

                // Pass 3: accumulate the sum of exponentials along the axis.
                for(std::size_t j = 0; j < n_dims; ++j)
                {
                    idx[tuned_axis] = j;
                    batch_sum[i] += output(idx.begin(), idx.end());
                }

                // Pass 4: combine each exp value with the row sum via the
                // op-specific functor (presumably division for softmax and a
                // log-based form for logsoftmax — confirm in the op headers).
                for(std::size_t j = 0; j < n_dims; ++j)
                {
                    idx[tuned_axis] = j;
                    output(idx.begin(), idx.end()) =
                        op.output()(output(idx.begin(), idx.end()), batch_sum[i]);
                }
            });
        });

        return result;
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
// Extracts the final time-step output of an RNN whose sequences have variable
// lengths: args[0] is the full RNN output, args[1] the per-batch sequence
// lengths.
struct cpu_rnn_var_sl_last_output
{
    op::rnn_var_sl_last_output op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::rnn_var_sl_last_output"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        return op.compute_shape(std::move(inputs));
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        // The output iteration space is the input with the time dimension
        // (dim 0) collapsed to a single step.
        auto out_comp_lens = args[0].get_shape().lens();
        out_comp_lens[0]   = 1;
        shape out_comp_s{output_shape.type(), out_comp_lens};

        visit_all(result, args[0])([&](auto output, auto input) {
            args[1].visit([&](auto seq_lens) {
                par_for(output_shape.elements(), [&](auto i) {
                    // Index layout appears to be {time, direction, batch, ...}:
                    // idx[1] is read as the direction slot and idx[2] as the
                    // batch index — confirm against op::rnn_var_sl_last_output.
                    auto idx = out_comp_s.multi(i);
                    auto b   = idx[2];
                    if(op.direction == op::rnn_direction::reverse or idx[1] == 1)
                    {
                        // Reverse-running sequences end at time step 0.
                        idx[0] = 0;
                    }
                    else
                    {
                        // Forward sequences end at their own length - 1.
                        idx[0] = seq_lens[b] - 1;
                    }
                    output[i] = input(idx.begin(), idx.end());
                });
            });
        });

        return result;
    }
};
MIGRAPHX_REGISTER_OP(cpu_rnn_var_sl_last_output)
Shucai Xiao's avatar
Shucai Xiao committed
883

Paul's avatar
Paul committed
884
885
struct cpu_apply
{
886
    module* prog;
Paul's avatar
Paul committed
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
    std::unordered_map<std::string, std::function<void(instruction_ref)>> apply_map{};

    template <class T>
    auto simple_op()
    {
        return [this](instruction_ref ins) { apply_simple_op<T>(ins); };
    }

    template <class T, class Op>
    auto extend_op()
    {
        return [this](instruction_ref ins) { apply_extend_op<T, Op>(ins); };
    }

    void init()
    {
Aditya Atluri's avatar
Aditya Atluri committed
903
        apply_map["batch_norm_inference"] =
904
            extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
905
        apply_map["convolution"] = extend_op<cpu_convolution<op::convolution>, op::convolution>();
kahmed10's avatar
kahmed10 committed
906
907
908
909
        apply_map["deconvolution"] =
            extend_op<cpu_deconvolution<op::deconvolution>, op::deconvolution>();
        apply_map["dot"]       = extend_op<cpu_gemm, op::dot>();
        apply_map["quant_dot"] = extend_op<cpu_quant_gemm, op::quant_dot>();
910
911
912
913
914
915
916
917
918
        apply_map["quant_convolution"] =
            extend_op<cpu_convolution<op::quant_convolution>, op::quant_convolution>();
        apply_map["elu"]        = extend_op<cpu_unary<elu_op>, op::elu>();
        apply_map["im2col"]     = extend_op<cpu_im2col, op::im2col>();
        apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
        apply_map["logsoftmax"] = extend_op<cpu_softmax<op::logsoftmax>, op::logsoftmax>();
        apply_map["lrn"]        = extend_op<cpu_lrn, op::lrn>();
        apply_map["pad"]        = extend_op<cpu_pad, op::pad>();
        apply_map["softmax"]    = extend_op<cpu_softmax<op::softmax>, op::softmax>();
Shucai Xiao's avatar
Shucai Xiao committed
919
920
        apply_map["rnn_var_sl_last_output"] =
            extend_op<cpu_rnn_var_sl_last_output, op::rnn_var_sl_last_output>();
Paul's avatar
Paul committed
921
922
923
924
925
926
927
    }

    void apply()
    {
        init();
        for(auto it : iterator_for(*prog))
        {
Khalique's avatar
Khalique committed
928
            if(it->name() == "pooling")
Paul's avatar
Paul committed
929
930
931
            {
                apply_pooling(it);
            }
Paul's avatar
Paul committed
932
            else if(apply_map.count(it->name()) > 0)
Paul's avatar
Paul committed
933
            {
Paul's avatar
Paul committed
934
                apply_map.at(it->name())(it);
Paul's avatar
Paul committed
935
            }
Paul's avatar
Paul committed
936
            else if(is_context_free(it->get_operator()))
937
938
939
            {
                apply_cpu_op(it);
            }
Paul's avatar
Paul committed
940
941
942
        }
    }

943
    void apply_cpu_op(instruction_ref ins) const
944
945
946
947
    {
        prog->replace_instruction(ins, cpu_op{ins->get_operator()}, ins->inputs());
    }

Paul's avatar
Paul committed
948
949
950
    template <class T>
    void apply_simple_op(instruction_ref ins)
    {
Paul's avatar
Paul committed
951
        prog->replace_instruction(ins, T{}, ins->inputs());
Paul's avatar
Paul committed
952
953
954
955
956
    }

    template <class T, class Op>
    void apply_extend_op(instruction_ref ins)
    {
957
        auto&& op = any_cast<Op>(ins->get_operator());
Paul's avatar
Paul committed
958
        prog->replace_instruction(ins, T{op}, ins->inputs());
Paul's avatar
Paul committed
959
960
    }

961
    void apply_pooling(instruction_ref ins) const
Paul's avatar
Paul committed
962
    {
963
        auto&& op = any_cast<op::pooling>(ins->get_operator());
Paul's avatar
Paul committed
964
        if(op.mode == "max")
Paul's avatar
Paul committed
965
            prog->replace_instruction(ins, cpu_pooling<max_pool>{op}, ins->inputs());
Paul's avatar
Paul committed
966
        else if(op.mode == "average")
Paul's avatar
Paul committed
967
            prog->replace_instruction(ins, cpu_pooling<avg_pool>{op}, ins->inputs());
Paul's avatar
Paul committed
968
969
970
    }
};

971
// Entry point of the lowering pass: rewrite every supported instruction in
// the module to its cpu:: implementation.
void lowering::apply(module& p) const { cpu_apply{&p}.apply(); }
Paul's avatar
Paul committed
972
973

} // namespace cpu
Paul's avatar
Paul committed
974
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
975
} // namespace migraphx