lowering.cpp 30 KB
Newer Older
Paul's avatar
Paul committed
1

Paul's avatar
Paul committed
2
3
4
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/dfor.hpp>
Paul's avatar
Paul committed
5
6
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/convolution.hpp>
kahmed10's avatar
kahmed10 committed
7
#include <migraphx/op/deconvolution.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
8
#include <migraphx/op/quant_convolution.hpp>
Paul's avatar
Paul committed
9
#include <migraphx/op/dot.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
10
#include <migraphx/op/quant_dot.hpp>
Paul's avatar
Paul committed
11
12
13
14
15
16
17
18
#include <migraphx/op/elu.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/softmax.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
19
20
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
21
#include <migraphx/op/rnn_var_sl_last_output.hpp>
Paul's avatar
Paul committed
22
23
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
Paul's avatar
Paul committed
24
#include <migraphx/par_dfor.hpp>
25
#include <migraphx/clamp.hpp>
Paul's avatar
Paul committed
26
#include <migraphx/cpu/gemm.hpp>
Paul's avatar
Paul committed
27
#include <unordered_map>
Paul's avatar
Paul committed
28
#include <utility>
Paul's avatar
Paul committed
29

Paul's avatar
Paul committed
30
namespace migraphx {
Paul's avatar
Paul committed
31
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
32
33
34
35
36
37
38
39
namespace cpu {

// Returns the value 0 converted to the type of the argument.
// The argument itself is only used for type deduction; its value is ignored.
template <class T>
T zero(const T& /*unused*/)
{
    return static_cast<T>(0);
}

Khalique's avatar
Khalique committed
40
41
42
43
// Converts `x` to the signed counterpart of its type.
// For integral T the return type is std::make_signed<T>::type; for any other
// T (e.g. floating point) the return type is T itself and the value is
// returned unchanged.
template <class T>
auto make_signed(T x) -> typename std::conditional_t<std::is_integral<T>::value,
                                                     std::make_signed<T>,
                                                     std::enable_if<true, T>>::type
{
    return x;
}

48
49
50
51
//
// cpu implementation of batch norm for inference
//
// inputs are:
52
53
54
55
// args[0] -> input data buffer
// args[1] -> gamma
// args[2] -> bias
// args[3] -> mini batch mean
Aditya Atluri's avatar
Aditya Atluri committed
56
// args[4] -> mini batch variance
57
58
59
//
// The equation to compute batch norm for inference is:
//
Aditya Atluri's avatar
Aditya Atluri committed
60
// output[i] = bias + gamma * (input[i] - mean) / sqrt(variance + epsilon)
61
62
63
64
65
//
// the input data format should be nchw
//
struct cpu_batch_norm_inference
{
66
    op::batch_norm_inference op;
67

68
69
70
71
72
73
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

74
75
    std::string name() const { return "cpu::batch_norm_inference"; }

Paul's avatar
Paul committed
76
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
77

Paul's avatar
Paul committed
78
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
79
    {
80
81
        argument output{output_shape};

Aditya Atluri's avatar
Aditya Atluri committed
82
83
        double epsilon           = op.epsilon;
        auto input               = args[0];
Paul's avatar
Paul committed
84
85
86
87
        auto arg_gamma           = args[1];
        auto arg_bias            = args[2];
        auto mini_batch_mean     = args[3];
        auto mini_batch_variance = args[4];
88

89
        if(op.bn_mode == op::batch_norm_inference::spatial)
Scott Thornton's avatar
Scott Thornton committed
90
91
92
        {
            visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
                [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
Shucai Xiao's avatar
Shucai Xiao committed
93
94
95
96
97
98
99
100
                    par_for(output_shape.elements(), [&](auto i) {
                        auto idx = output_shape.multi(i);
                        auto c   = idx[1];
                        assert((variance[c] + epsilon) > 0);
                        result[i] =
                            gamma[c] * (buffer[i] - mean[c]) / std::sqrt(variance[c] + epsilon) +
                            bias[c];
                    });
Scott Thornton's avatar
Scott Thornton committed
101
                });
102
103
        }

104
        if(op.bn_mode == op::batch_norm_inference::per_activation)
Scott Thornton's avatar
Scott Thornton committed
105
        {
Shucai Xiao's avatar
Shucai Xiao committed
106
            visit_all(output, input, mini_batch_mean, mini_batch_variance, arg_gamma, arg_bias)(
107
                [&](auto result, auto buffer, auto mean, auto variance, auto gamma, auto bias) {
Shucai Xiao's avatar
Shucai Xiao committed
108
109
110
111
112
113
114
115
116
117
                    par_for(output_shape.elements(), [&](auto i) {
                        auto idx   = output_shape.multi(i);
                        idx[0]     = 0;
                        auto index = output_shape.index(idx);

                        assert((variance[index] + epsilon) > 0);
                        result[i] = gamma[index] * (buffer[i] - mean[index]) /
                                        std::sqrt(variance[index] + epsilon) +
                                    bias[index];
                    });
Scott Thornton's avatar
Scott Thornton committed
118
                });
119
        }
120
121
122
123
124

        return output;
    }
};

Khalique's avatar
Khalique committed
125
struct cpu_lrn
Khalique's avatar
Khalique committed
126
{
Khalique's avatar
Khalique committed
127
    op::lrn op;
Khalique's avatar
Khalique committed
128

129
130
131
132
133
134
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

Khalique's avatar
Khalique committed
135
    std::string name() const { return "cpu::lrn"; }
Khalique's avatar
Khalique committed
136
137
138
139
140
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0])([&](auto output, auto input) {
Khalique's avatar
Khalique committed
141
142
143
144
            int n_batch         = output_shape.lens()[0];
            int channels        = output_shape.lens()[1];
            int height          = output_shape.lens()[2];
            int width           = output_shape.lens()[3];
Paul's avatar
Paul committed
145
            float alphaoverarea = op.alpha / float(op.size);
146
147
            int radius_lower    = (op.size - 1) / 2;
            int radius_upper    = op.size / 2 + 1;
Khalique's avatar
Khalique committed
148

149
            par_dfor(n_batch, height, width)([&](int b, int h, int w) {
Khalique's avatar
Khalique committed
150
                float scale = 0;
Khalique's avatar
Khalique committed
151
                dfor(channels)([&](int c) {
152
153
                    auto start = (c - radius_lower) < 0 ? 0 : (c - radius_lower);
                    auto end   = (c + radius_upper) > channels ? channels : (c + radius_upper);
Khalique's avatar
Khalique committed
154
155
                    for(auto k = start; k < end; ++k)
                    {
Khalique's avatar
Khalique committed
156
                        scale += std::pow(input(b, k, h, w), 2);
Khalique's avatar
Khalique committed
157
158
159
                    }
                    scale *= alphaoverarea;
                    scale += op.bias;
Khalique's avatar
Khalique committed
160
                    scale              = std::pow(scale, -op.beta);
Khalique's avatar
Khalique committed
161
162
163
164
165
166
167
168
                    output(b, c, h, w) = input(b, c, h, w) * scale;
                });
            });
        });
        return result;
    }
};

Paul Fultz II's avatar
Paul Fultz II committed
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
// Visits `x` on its own and then visits all of `xs` together through
// visit_all, finally invoking `v` with the view of `x` followed by the views
// of `xs`.
// NOTE(review): visiting `x` separately presumably allows it to have a
// different element type than `xs` (e.g. an int32 output with int8 inputs) --
// confirm against visit_all's requirements.
template <class V, class T, class... Ts>
void visit_quantize_impl(V&& v, T&& x, Ts&&... xs)
{
    x.visit([&](auto y) { visit_all(xs...)([&](auto... ys) { v(y, ys...); }); });
}

// Returns a callable that, given a visitor `v`, forwards to
// visit_quantize_impl(v, x, xs...).
// The returned lambda captures `x` and `xs` by reference, so it must be
// invoked while those arguments are still alive.
template <class T, class... Ts>
auto visit_quantize(T&& x, Ts&&... xs)
{
    return [&](auto v) {
        // Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
        // (the pack expansion is done in a separate function instead of
        // inside this lambda).
        visit_quantize_impl(v, x, xs...);
    };
}

184
template <class Op>
Paul's avatar
Paul committed
185
186
struct cpu_convolution
{
187
    Op op;
188

189
190
191
192
193
194
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

195
    std::string name() const { return "cpu::" + op.name(); }
196
197
198
199
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
Paul Fultz II's avatar
Paul Fultz II committed
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
        visit_quantize(result, args[0], args[1])([&](auto output, auto input, auto weights) {
            auto in_lens = input.get_shape().lens();

            auto wei_lens = weights.get_shape().lens();
            auto wei_n    = wei_lens[0];
            auto wei_c    = wei_lens[1];
            std::vector<std::size_t> win_size(wei_lens.begin() + 1, wei_lens.end());

            par_for(output_shape.elements(), [&](auto i) {
                auto idx_o = output_shape.multi(i);
                auto w     = idx_o[1];
                auto n_dim = idx_o.size();

                std::vector<std::ptrdiff_t> win_start;
                for(std::size_t dim = 2; dim < n_dim; ++dim)
                {
                    auto d_2 = dim - 2;
                    win_start.push_back(std::ptrdiff_t(idx_o[dim] * op.stride[d_2]) -
                                        std::ptrdiff_t(op.padding[d_2]));
                }
                const auto group_id = w / (wei_n / op.group);

                shape win_shape{output_shape.type(), win_size};

                double acc = 0.0;
                shape_for_each(win_shape, [&](auto idx_win) {
                    auto k           = idx_win[0];
                    const auto in_ch = group_id * wei_c + k;
                    std::vector<std::ptrdiff_t> idx(idx_o.begin(), idx_o.end());
                    idx[1] = in_ch;
                    std::transform(idx_win.begin() + 1,
                                   idx_win.end(),
                                   win_start.begin(),
                                   idx.begin() + 2,
                                   [](std::ptrdiff_t ii, std::ptrdiff_t jj) { return ii + jj; });
                    std::vector<std::ptrdiff_t> idx_wei(idx_o.size());
                    idx_wei[0] = w;
                    std::copy(idx_win.begin(), idx_win.end(), idx_wei.begin() + 1);
                    if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
                       std::equal(idx.begin(),
                                  idx.end(),
                                  in_lens.begin(),
                                  in_lens.end(),
                                  std::less<std::ptrdiff_t>{}))
                    {
                        acc +=
                            input(idx.begin(), idx.end()) * weights(idx_wei.begin(), idx_wei.end());
                    }
                });

                output[i] = acc;
251
            });
252
253
254
255
256
        });
        return result;
    }
};

kahmed10's avatar
kahmed10 committed
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
template <class Op>
struct cpu_deconvolution
{
    Op op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::" + op.name(); }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

    // 2-D transposed convolution computed by scattering.
    //
    // args[0] is the input and args[1] the weights; both are indexed with
    // four dimensions. The output is zero-filled first, then every
    // (input element, kernel tap) pair adds input * weight to the output
    // position it maps to, provided that position lies inside the output.
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        visit_all(result, args[0], args[1])([&](auto output, auto input, auto weights) {
            using value_type = typename decltype(output)::value_type;

            std::fill(output.begin(), output.end(), value_type{0});

            auto out_dims   = output_shape.lens();
            auto out_height = out_dims[2];
            auto out_width  = out_dims[3];

            auto in_dims     = input.get_shape().lens();
            auto batch       = in_dims[0];
            auto in_channels = in_dims[1];
            auto in_height   = in_dims[2];
            auto in_width    = in_dims[3];

            auto w_dims    = weights.get_shape().lens();
            auto w_first   = w_dims[0];
            auto w_second  = w_dims[1];
            auto kernel_h  = w_dims[2];
            auto kernel_w  = w_dims[3];

            // Parallel over (batch, second weight dimension); the remaining
            // loops run sequentially inside each task.
            par_dfor(batch, w_second)([&](std::size_t n, std::size_t k) {

                dfor(in_channels, in_height, in_width, kernel_h, kernel_w)(
                    [&](std::size_t c, std::size_t r, std::size_t s, std::size_t x, std::size_t y) {
                        const int start_x = r * op.stride[0] - op.padding[0];
                        const int start_y = s * op.stride[1] - op.padding[1];
                        const int out_x   = start_x + x * op.dilation[0];
                        const int out_y   = start_y + y * op.dilation[1];

                        const auto group_id = c / (w_first / op.group);
                        const auto out_ch   = group_id * w_second + k;

                        // Drop contributions that fall outside the output.
                        if(out_x < 0 || out_x >= out_height || out_y < 0 || out_y >= out_width)
                            return;

                        output(n, out_ch, out_x, out_y) += input(n, c, r, s) * weights(c, k, x, y);
                    });
            });
        });
        return result;
    }
};

Scott Thornton's avatar
Scott Thornton committed
318
319
struct cpu_im2col
{
320
    op::im2col op;
Scott Thornton's avatar
Scott Thornton committed
321

322
323
324
325
326
327
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

Scott Thornton's avatar
Scott Thornton committed
328
329
    static std::string name() { return "cpu::im2col"; }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
Scott Thornton's avatar
Scott Thornton committed
330

wsttiger's avatar
wsttiger committed
331
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
332
    {
Scott Thornton's avatar
Scott Thornton committed
333
        argument result{output_shape};
Scott Thornton's avatar
Scott Thornton committed
334
        auto input_shape   = args[0].get_shape();
Scott Thornton's avatar
Scott Thornton committed
335
336
        auto weights_shape = args[1].get_shape();
        visit_all(result, args[0])([&](auto col, auto input) {
Scott Thornton's avatar
Scott Thornton committed
337
338
            const std::size_t& height   = input_shape.lens()[2];
            const std::size_t& width    = input_shape.lens()[3];
Scott Thornton's avatar
Scott Thornton committed
339
340
341
            const std::size_t& channels = weights_shape.lens()[1];
            const std::size_t& kernel_h = weights_shape.lens()[2];
            const std::size_t& kernel_w = weights_shape.lens()[3];
Scott Thornton's avatar
Scott Thornton committed
342
343
            const std::size_t& pad_h    = op.padding[0];
            const std::size_t& pad_w    = op.padding[1];
Scott Thornton's avatar
Scott Thornton committed
344
345
346
            const std::size_t& stride_h = op.stride[0];
            const std::size_t& stride_w = op.stride[1];

Paul's avatar
Paul committed
347
348
            long kdiv2_h = long(kernel_h) / 2;
            long kdiv2_w = long(kernel_w) / 2;
Scott Thornton's avatar
Scott Thornton committed
349
            // calculate output sizes
Scott Thornton's avatar
Scott Thornton committed
350
351
            const std::size_t col_height = (height - kernel_h + 2 * pad_h) / stride_h + 1;
            const std::size_t col_width  = (width - kernel_w + 2 * pad_w) / stride_w + 1;
wsttiger's avatar
wsttiger committed
352
            // account for padding for the starting position of the input pixels
Paul's avatar
Paul committed
353
            long iinput = kdiv2_h - long(pad_h);
wsttiger's avatar
wsttiger committed
354
            // loop over output pixels (ioutput, joutput)
Scott Thornton's avatar
Scott Thornton committed
355
356
            for(std::size_t ioutput = 0; ioutput < col_height; ioutput++, iinput += stride_h)
            {
Paul's avatar
Paul committed
357
                long jinput = kdiv2_w - long(pad_w);
Scott Thornton's avatar
Scott Thornton committed
358
359
360
361
362
                for(std::size_t joutput = 0; joutput < col_width; joutput++, jinput += stride_w)
                {
                    // compute linear index for output
                    std::size_t ldx = ioutput * col_width + joutput;
                    std::size_t p   = 0;
wsttiger's avatar
wsttiger committed
363
364
365
                    dfor(channels,
                         kernel_h,
                         kernel_w)([&](std::size_t c, std::size_t koffset, std::size_t loffset) {
Paul's avatar
Paul committed
366
367
                        auto idx    = iinput + long(koffset) - kdiv2_h;
                        auto jdx    = jinput + long(loffset) - kdiv2_w;
wsttiger's avatar
wsttiger committed
368
369
370
371
372
                        col(ldx, p) = ((idx >= 0) && (idx < height) && (jdx >= 0) && (jdx < width))
                                          ? input(0, c, idx, jdx)
                                          : 0;
                        p++;
                    });
Scott Thornton's avatar
Scott Thornton committed
373
374
                }
            }
Scott Thornton's avatar
Scott Thornton committed
375
        });
Scott Thornton's avatar
Scott Thornton committed
376
377
378
379
        return result;
    }
};

Paul's avatar
Paul committed
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
// Reduction policy for max pooling: the accumulator keeps the largest value
// seen so far and the window size is ignored when finishing.
struct max_pool
{
    static std::string name() { return "max"; }

    // Initial accumulator: the lowest finite double.
    static double start() { return std::numeric_limits<double>::lowest(); }

    // Returns the larger of the two values (`x` when neither is greater).
    static double apply(double x, double y) { return (x < y) ? y : x; }

    // The accumulated maximum is already the result.
    static double final(double x, double) { return x; }
};

// Reduction policy for average pooling: values are summed and the sum is
// divided by the second argument of final() when finishing.
struct avg_pool
{
    static std::string name() { return "average"; }

    // Initial accumulator for the running sum.
    static double start() { return 0.0; }

    // Adds `y` to the running sum `x`.
    static double apply(double x, double y)
    {
        const double sum = x + y;
        return sum;
    }

    // Divides the sum `x` by the element count `y`.
    static double final(double x, double y)
    {
        const double mean = x / y;
        return mean;
    }
};

template <class Op>
struct cpu_pooling
{
407
    op::pooling op;
Paul's avatar
Paul committed
408

409
410
411
412
413
414
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

Paul's avatar
Paul committed
415
    std::string name() const { return "cpu::pooling_" + Op::name(); }
Paul's avatar
Paul committed
416
417
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
418
419
420
    {
        argument result{output_shape};
        visit_all(result, args[0])([&](auto output, auto input) {
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
            using type   = typename decltype(output)::value_type;
            auto in_s    = input.get_shape();
            auto in_lens = in_s.lens();
            std::vector<std::size_t> vec_len(in_lens.begin() + 2, in_lens.end());

            par_for(output_shape.elements(), [&](auto i) {
                auto idx_o = output_shape.multi(i);
                auto n_dim = idx_o.size();
                std::vector<std::size_t> win_start;
                std::vector<std::size_t> win_size;
                for(std::size_t dim = 2; dim < n_dim; ++dim)
                {
                    auto d_2  = dim - 2;
                    int start = static_cast<int>(idx_o[dim] * op.stride[d_2]) -
                                static_cast<int>(op.padding[d_2]);
                    int end = std::min(start + op.lengths[d_2], in_lens[dim]);
                    start   = std::max(start, 0);
                    win_start.push_back(start);
                    win_size.push_back(end - start);
                }

                shape win_shape{output_shape.type(), win_size};
                auto pool_size = win_shape.elements();
                double acc     = Op::start();
                shape_for_each(win_shape, [&](auto idx_w) {
                    auto idx = idx_o;
                    std::transform(idx_w.begin(),
                                   idx_w.end(),
                                   win_start.begin(),
                                   idx.begin() + 2,
                                   [](auto ii, auto jj) { return ii + jj; });
                    if(std::all_of(idx.begin() + 2, idx.end(), [&](auto ii) { return ii >= 0; }) and
                       idx < in_lens)
                    {
                        acc = Op::apply(acc, input[in_s.index(idx)]);
                    }
Paul's avatar
Paul committed
457
                });
458
459
460

                output[i] = type(Op::final(acc, pool_size));
            });
Paul's avatar
Paul committed
461
        });
462

Paul's avatar
Paul committed
463
464
465
466
        return result;
    }
};

467
struct cpu_op
Paul's avatar
Paul committed
468
{
469
470
    operation op;
    std::string name() const { return "cpu::" + op.name(); }
Paul's avatar
Paul committed
471
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
Paul's avatar
Paul committed
472
    argument compute(context&, const shape& output_shape, const std::vector<argument>& args) const
Paul's avatar
Paul committed
473
    {
Paul's avatar
Paul committed
474
        return op.compute(output_shape, args);
Paul's avatar
Paul committed
475
    }
Paul's avatar
Paul committed
476
    friend bool operator==(const cpu_op& x, const cpu_op& y) { return x.op == y.op; }
477
    friend bool operator==(const cpu_op& x, const operation& y)
Paul's avatar
Paul committed
478
    {
479
480
481
        if(x.name() != y.name())
            return false;
        return x == any_cast<cpu_op>(y);
Paul's avatar
Paul committed
482
    }
Paul's avatar
Paul committed
483
    friend bool operator==(const operation& x, const cpu_op& y) { return y == x; }
Paul's avatar
Paul committed
484
485
};

Khalique's avatar
Khalique committed
486
struct cpu_pad
487
{
Khalique's avatar
Khalique committed
488
    op::pad op;
489
490
491
492
493
494
495

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

Khalique's avatar
Khalique committed
496
    std::string name() const { return "cpu::contiguous"; }
497
498
499
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
Khalique's avatar
Khalique committed
500
        assert(output_shape.standard());
501
        argument result{output_shape};
502
503
504
505
        result.visit([&](auto output) {
            using type = typename decltype(output)::value_type;
            std::fill(output.begin(), output.end(), pad_clamp<type>(op.value));
        });
Khalique's avatar
Khalique committed
506
507

        visit_all(result, args[0])([&](auto output, auto input) {
508
            shape_for_each(input.get_shape(), [&](const auto& idx) {
Khalique's avatar
Khalique committed
509
                std::vector<std::size_t> new_idx(idx.size());
Khalique's avatar
Khalique committed
510
511
512
513
                std::transform(
                    idx.begin(), idx.end(), op.pads.begin(), new_idx.begin(), [](auto i, auto j) {
                        return i + j;
                    });
Khalique's avatar
Khalique committed
514
                output(new_idx.begin(), new_idx.end()) = input(idx.begin(), idx.end());
515
            });
Khalique's avatar
Khalique committed
516
517
        });

518
519
520
521
        return result;
    }
};

Paul's avatar
Paul committed
522
523
struct cpu_gemm
{
Shucai Xiao's avatar
Shucai Xiao committed
524
    op::dot op;
525
526
527
528
529
530

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }
Shucai Xiao's avatar
Shucai Xiao committed
531
    std::string name() const { return "cpu::dot"; }
Shucai Xiao's avatar
Shucai Xiao committed
532
533
    shape compute_shape(const std::vector<shape>& inputs) const
    {
Shucai Xiao's avatar
Shucai Xiao committed
534
535
536
537
538
        if(inputs.size() == 3)
        {
            auto c_shape = inputs.at(2);
            check_shapes{{c_shape}}.not_broadcasted();
        }
Shucai Xiao's avatar
Shucai Xiao committed
539
        return op.compute_shape(inputs);
Shucai Xiao's avatar
Shucai Xiao committed
540
    }
Paul's avatar
Paul committed
541

Paul's avatar
Paul committed
542
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
543
544
    {
        argument result{output_shape};
Shucai Xiao's avatar
Shucai Xiao committed
545
        // 3 inputs, it is alpha * A * B + beta * C, then
546
        // A and B are matrices, and C is of the same shape as A * B
Shucai Xiao's avatar
Shucai Xiao committed
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
        if(args.size() == 3)
        {
            // no need to consider the value of args[2]
            if(op.beta == 0.0f)
            {
                result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
            }
            else
            {
                visit_all(result, args[2])([&](auto output, auto input) {
                    std::copy(input.begin(), input.end(), output.begin());
                });
            }

            migemm(result, args[0], args[1], op.alpha, op.beta);

            return result;
        }

        // 2 input arguments
        migemm(result, args[0], args[1], op.alpha, 0.0f);

Paul's avatar
Paul committed
569
570
571
572
        return result;
    }
};

573
574
575
struct cpu_quant_gemm
{
    op::quant_dot op;
576
577
578
579
580
581
582

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
    std::string name() const { return "cpu::quant_dot"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        if(inputs.size() == 3)
        {
            auto c_shape = inputs.at(2);
            check_shapes{{c_shape}}.not_broadcasted();
        }
        return op.compute_shape(inputs);
    }

    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        // 3 inputs, it is alpha * A * B + beta * C, then
        // A and B are matrices, and C is of the same shape to A * B

        // first, convert the args[0] and args[1] from int8_t to int32_t
        argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
        argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
        arg_0.visit([&](auto output) {
Shucai Xiao's avatar
Shucai Xiao committed
604
605
            args.at(0).visit(
                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
606
607
608
        });

        arg_1.visit([&](auto output) {
Shucai Xiao's avatar
Shucai Xiao committed
609
610
            args.at(1).visit(
                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
        });

        if(args.size() == 3)
        {
            // no need to consider the value of args[2]
            if(op.beta == 0)
            {
                result.visit([&](auto output) { std::fill(output.begin(), output.end(), 0); });
            }
            else
            {
                visit_all(result, args[2])([&](auto output, auto input) {
                    std::copy(input.begin(), input.end(), output.begin());
                });
            }

            migemm(result, arg_0, arg_1, op.alpha, op.beta);

            return result;
        }

        // 2 input arguments
633
        migemm(result, arg_0, arg_1, op.alpha, int32_t{0});
634
635
636
637
638

        return result;
    }
};

Khalique's avatar
Khalique committed
639
640
641
642
643
644
struct leaky_relu_op
{
    op::leaky_relu op;
    std::string name() const { return "cpu::leaky_relu"; }
    auto fcn() const
    {
Paul's avatar
Paul committed
645
        auto a = op.alpha;
Khalique's avatar
Khalique committed
646
647
648
649
        return [a](auto x) { return x > 0 ? x : x * a; };
    }
};

Khalique's avatar
Khalique committed
650
651
652
653
654
655
struct elu_op
{
    op::elu op;
    std::string name() const { return "cpu::elu"; }
    auto fcn() const
    {
Paul's avatar
Paul committed
656
        auto a = op.alpha;
Khalique's avatar
Khalique committed
657
658
659
660
        return [a](auto x) { return x > 0 ? x : a * std::expm1(x); };
    }
};

Paul's avatar
Paul committed
661
662
663
664
template <typename Op>
struct cpu_unary
{
    Op op;
665
666
667
668
669
670

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op.op, f);
    }
Paul's avatar
Paul committed
671
    std::string name() const { return op.name(); }
Shucai Xiao's avatar
Shucai Xiao committed
672
    shape compute_shape(const std::vector<shape>& inputs) const
673
    {
Shucai Xiao's avatar
Shucai Xiao committed
674
675
        check_shapes{inputs}.has(1);
        auto s = inputs.at(0);
676
        return {s.type(), s.lens()};
677
678
    }

Paul's avatar
Paul committed
679
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
680
681
    {
        argument result{output_shape};
682
683
684
        visit_all(result, args[0])([&](auto output, auto input) {
            assert(input.get_shape().standard());
            std::transform(input.begin(), input.end(), output.begin(), op.fcn());
Paul's avatar
Paul committed
685
        });
686

Paul's avatar
Paul committed
687
688
689
690
        return result;
    }
};

691
template <class Op>
Khalique's avatar
Khalique committed
692
struct cpu_softmax
Paul's avatar
Paul committed
693
{
694
    Op op;
Khalique's avatar
Khalique committed
695
696
697
698
699
700
701

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

702
    std::string name() const { return "cpu::" + op.name(); }
Khalique's avatar
Khalique committed
703
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
Paul's avatar
Paul committed
704
    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
705
706
    {
        argument result{output_shape};
707
708
709
710
        auto batch_lens    = output_shape.lens();
        int64_t tuned_axis = (op.axis < 0) ? op.axis + args[0].get_shape().lens().size() : op.axis;
        std::size_t n_dims = batch_lens[tuned_axis];
        batch_lens[tuned_axis] = 1;
711
712
        shape batch_shape{shape::int32_type, batch_lens};

Paul's avatar
Paul committed
713
714
        visit_all(result, args[0])([&](auto output, auto input) {
            using value_type = typename decltype(input)::value_type;
Shucai Xiao's avatar
Shucai Xiao committed
715
716
            std::vector<value_type> batch_max(batch_shape.elements(),
                                              std::numeric_limits<value_type>::lowest());
717
718
            std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
            par_for(batch_shape.elements(), [&](auto i) {
719
                auto idx = batch_shape.multi(i);
Shucai Xiao's avatar
Shucai Xiao committed
720
                for(std::size_t j = 0; j < n_dims; ++j)
721
                {
722
723
                    idx[tuned_axis] = j;
                    batch_max[i]    = std::max(batch_max[i], input(idx.begin(), idx.end()));
724
                }
Khalique's avatar
Khalique committed
725

Shucai Xiao's avatar
Shucai Xiao committed
726
                for(std::size_t j = 0; j < n_dims; ++j)
727
                {
728
                    idx[tuned_axis]   = j;
Shucai Xiao's avatar
Shucai Xiao committed
729
730
                    std::size_t index = output_shape.index(idx);
                    output[index]     = std::exp(input[index] - batch_max[i]);
731
                }
Khalique's avatar
Khalique committed
732

Shucai Xiao's avatar
Shucai Xiao committed
733
                for(std::size_t j = 0; j < n_dims; ++j)
734
                {
735
                    idx[tuned_axis] = j;
736
737
                    batch_sum[i] += output(idx.begin(), idx.end());
                }
Khalique's avatar
Khalique committed
738

Shucai Xiao's avatar
Shucai Xiao committed
739
                for(std::size_t j = 0; j < n_dims; ++j)
740
                {
741
                    idx[tuned_axis] = j;
742
743
                    output(idx.begin(), idx.end()) =
                        op.output()(output(idx.begin(), idx.end()), batch_sum[i]);
744
                }
Shucai Xiao's avatar
Shucai Xiao committed
745
746
747
748
749
750
751
            });
        });

        return result;
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
// CPU implementation of op::rnn_var_sl_last_output: for every output element
// it copies one element of args[0], choosing the position along dimension 0
// from the operator's direction and the per-entry lengths in args[1].
struct cpu_rnn_var_sl_last_output
{
    op::rnn_var_sl_last_output op;

    // Reflection is forwarded to the wrapped operator.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::rnn_var_sl_last_output"; }

    // Shape inference is delegated to the wrapped operator.
    shape compute_shape(std::vector<shape> inputs) const
    {
        return op.compute_shape(std::move(inputs));
    }

    // args[0] is the tensor to gather from; args[1] holds the lengths that
    // are indexed by dimension 2 of the multi-index.
    // NOTE(review): dimensions look like (step, direction, batch, ...) - confirm
    // against op::rnn_var_sl_last_output.
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};

        // Same dimensions as the input but with dimension 0 collapsed to 1,
        // so a flat output index can be expanded into an input multi-index.
        auto collapsed_lens = args[0].get_shape().lens();
        collapsed_lens[0]   = 1;
        shape collapsed_shape{output_shape.type(), collapsed_lens};

        visit_all(result, args[0])([&](auto output, auto input) {
            args[1].visit([&](auto seq_lens) {
                par_for(output_shape.elements(), [&](auto i) {
                    auto idx         = collapsed_shape.multi(i);
                    const auto entry = idx[2];

                    // Position 0 is used for the reverse direction or when
                    // dimension 1 of the index equals 1; otherwise the last
                    // valid position for this entry is used.
                    const bool use_first =
                        (op.direction == op::rnn_direction::reverse) or (idx[1] == 1);
                    idx[0] = 0;
                    if(not use_first)
                    {
                        idx[0] = seq_lens[entry] - 1;
                    }

                    output[i] = input(idx.begin(), idx.end());
                });
            });
        });

        return result;
    }
};

Paul's avatar
Paul committed
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
struct cpu_apply
{
    program* prog;
    std::unordered_map<std::string, std::function<void(instruction_ref)>> apply_map{};

    template <class T>
    auto simple_op()
    {
        return [this](instruction_ref ins) { apply_simple_op<T>(ins); };
    }

    template <class T, class Op>
    auto extend_op()
    {
        return [this](instruction_ref ins) { apply_extend_op<T, Op>(ins); };
    }

    void init()
    {
Aditya Atluri's avatar
Aditya Atluri committed
817
        apply_map["batch_norm_inference"] =
818
            extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
819
        apply_map["convolution"] = extend_op<cpu_convolution<op::convolution>, op::convolution>();
kahmed10's avatar
kahmed10 committed
820
821
822
823
        apply_map["deconvolution"] =
            extend_op<cpu_deconvolution<op::deconvolution>, op::deconvolution>();
        apply_map["dot"]       = extend_op<cpu_gemm, op::dot>();
        apply_map["quant_dot"] = extend_op<cpu_quant_gemm, op::quant_dot>();
824
825
826
827
828
829
830
831
832
        apply_map["quant_convolution"] =
            extend_op<cpu_convolution<op::quant_convolution>, op::quant_convolution>();
        apply_map["elu"]        = extend_op<cpu_unary<elu_op>, op::elu>();
        apply_map["im2col"]     = extend_op<cpu_im2col, op::im2col>();
        apply_map["leaky_relu"] = extend_op<cpu_unary<leaky_relu_op>, op::leaky_relu>();
        apply_map["logsoftmax"] = extend_op<cpu_softmax<op::logsoftmax>, op::logsoftmax>();
        apply_map["lrn"]        = extend_op<cpu_lrn, op::lrn>();
        apply_map["pad"]        = extend_op<cpu_pad, op::pad>();
        apply_map["softmax"]    = extend_op<cpu_softmax<op::softmax>, op::softmax>();
Shucai Xiao's avatar
Shucai Xiao committed
833
834
        apply_map["rnn_var_sl_last_output"] =
            extend_op<cpu_rnn_var_sl_last_output, op::rnn_var_sl_last_output>();
Paul's avatar
Paul committed
835
836
837
838
839
840
841
    }

    void apply()
    {
        init();
        for(auto it : iterator_for(*prog))
        {
Khalique's avatar
Khalique committed
842
            if(it->name() == "pooling")
Paul's avatar
Paul committed
843
844
845
            {
                apply_pooling(it);
            }
Paul's avatar
Paul committed
846
            else if(apply_map.count(it->name()) > 0)
Paul's avatar
Paul committed
847
            {
Paul's avatar
Paul committed
848
                apply_map.at(it->name())(it);
Paul's avatar
Paul committed
849
            }
Paul's avatar
Paul committed
850
            else if(is_context_free(it->get_operator()))
851
852
853
            {
                apply_cpu_op(it);
            }
Paul's avatar
Paul committed
854
855
856
        }
    }

857
858
859
860
861
    void apply_cpu_op(instruction_ref ins)
    {
        prog->replace_instruction(ins, cpu_op{ins->get_operator()}, ins->inputs());
    }

Paul's avatar
Paul committed
862
863
864
    template <class T>
    void apply_simple_op(instruction_ref ins)
    {
Paul's avatar
Paul committed
865
        prog->replace_instruction(ins, T{}, ins->inputs());
Paul's avatar
Paul committed
866
867
868
869
870
    }

    template <class T, class Op>
    void apply_extend_op(instruction_ref ins)
    {
871
        auto&& op = any_cast<Op>(ins->get_operator());
Paul's avatar
Paul committed
872
        prog->replace_instruction(ins, T{op}, ins->inputs());
Paul's avatar
Paul committed
873
874
875
876
    }

    void apply_pooling(instruction_ref ins)
    {
877
        auto&& op = any_cast<op::pooling>(ins->get_operator());
Paul's avatar
Paul committed
878
        if(op.mode == "max")
Paul's avatar
Paul committed
879
            prog->replace_instruction(ins, cpu_pooling<max_pool>{op}, ins->inputs());
Paul's avatar
Paul committed
880
        else if(op.mode == "average")
Paul's avatar
Paul committed
881
            prog->replace_instruction(ins, cpu_pooling<avg_pool>{op}, ins->inputs());
Paul's avatar
Paul committed
882
883
884
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
885
// Entry point of the pass: runs cpu_apply over program `p`.
void lowering::apply(program& p) const
{
    cpu_apply lowerer{&p};
    lowerer.apply();
}
Paul's avatar
Paul committed
886
887

} // namespace cpu
Paul's avatar
Paul committed
888
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
889
} // namespace migraphx