operators.hpp 41.2 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_OPERATORS_HPP
#define MIGRAPHX_GUARD_OPERATORS_HPP
Paul's avatar
Paul committed
3

4
#include <array>
Paul's avatar
Paul committed
5
6
7
8
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
9
10
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
Paul's avatar
Paul committed
11
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
12
#include <cmath>
Paul's avatar
Paul committed
13
#include <utility>
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
namespace migraphx {
Paul's avatar
Paul committed
16
inline namespace MIGRAPHX_INLINE_NS {
17
namespace op {
Paul's avatar
Paul committed
18

19
20
21
22
23
24
25
/// Selects how the spatial output size is derived by the operators in this
/// file that carry a `padding_mode` member (convolution, im2col, pooling).
enum padding_mode_t
{
    // In convolution and pooling: output size is computed from the operator's
    // explicit `padding` values.
    default_, // NOLINT
    // In convolution and pooling: output size is ceil(input / stride).
    same,
    // In convolution and pooling: output size is computed without any padding term.
    valid
};

Paul's avatar
Paul committed
26
27
struct not_computable
{
Paul's avatar
Paul committed
28
    argument compute(const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
29
    {
Paul's avatar
Paul committed
30
        MIGRAPHX_THROW("not computable");
Paul's avatar
Paul committed
31
    }
Paul's avatar
Paul committed
32
33
};

34
35
struct batch_norm_inference
{
36
37
    float epsilon  = 1.0e-6f;
    float momentum = 0.9f;
38
39
40

    std::string name() const { return "batch_norm_inference"; }

41
42
43
44
45
46
47
48
    enum bn_infer_mode_t
    {
        per_activation,
        spatial,
    };

    bn_infer_mode_t bn_mode = spatial;

Paul's avatar
Paul committed
49
50
51
52
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(
Paul's avatar
Paul committed
53
            f(self.epsilon, "epsilon"), f(self.momentum, "momentum"), f(self.bn_mode, "bn_mode"));
Paul's avatar
Paul committed
54
    }
55

56
57
58
59
60
61
62
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        return inputs.front();
    }
};

Khalique's avatar
Khalique committed
63
struct lrn
Khalique's avatar
Khalique committed
64
65
{
    float alpha = 0.0001;
Khalique's avatar
Khalique committed
66
67
    float beta  = 0.75;
    float bias  = 1.0;
Khalique's avatar
Khalique committed
68
    int size    = 1;
Khalique's avatar
Khalique committed
69
    std::string name() const { return "lrn"; }
Khalique's avatar
Khalique committed
70

Khalique's avatar
Khalique committed
71
72
73
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Khalique's avatar
Khalique committed
74
75
76
77
        return pack(f(self.alpha, "alpha"),
                    f(self.beta, "beta"),
                    f(self.bias, "bias"),
                    f(self.size, "size"));
Khalique's avatar
Khalique committed
78
79
    }

Khalique's avatar
Khalique committed
80
81
82
83
84
85
86
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }
};

Paul's avatar
Paul committed
87
struct convolution
Paul's avatar
Paul committed
88
{
Paul's avatar
Paul committed
89
90
91
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Khalique's avatar
Khalique committed
92

Paul's avatar
Paul committed
93
    padding_mode_t padding_mode = default_;
Khalique's avatar
Khalique committed
94
    int group                   = 1;
Paul's avatar
Paul committed
95
96
97
98

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
99
100
101
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
Khalique's avatar
Khalique committed
102
103
                    f(self.padding_mode, "padding_mode"),
                    f(self.group, "group"));
Paul's avatar
Paul committed
104
105
    }

Paul's avatar
Paul committed
106
    std::string name() const { return "convolution"; }
Paul's avatar
Paul committed
107
108
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
109
        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
Paul's avatar
Paul committed
110

Paul's avatar
Paul committed
111
        const shape& input   = inputs.at(0);
Paul's avatar
Paul committed
112
        const shape& weights = inputs.at(1);
Paul's avatar
Paul committed
113
        auto t               = input.type();
Paul's avatar
Paul committed
114
115
        if(padding_mode == default_)
        {
Paul's avatar
Paul committed
116
117
118
            return {t,
                    {
                        input.lens()[0],
Khalique's avatar
Khalique committed
119
                        weights.lens()[0],
Paul's avatar
Paul committed
120
121
122
123
124
125
126
127
128
129
130
131
132
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
                             2 * padding[0]) /
                                    stride[0] +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
                             2 * padding[1]) /
                                    stride[1] +
                                1)),
                    }};
Paul's avatar
Paul committed
133
134
135
136
137
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
Khalique's avatar
Khalique committed
138
                     weights.lens()[0],
Paul's avatar
Paul committed
139
140
141
142
143
144
145
146
147
148
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {
                t,
                {input.lens()[0],
Khalique's avatar
Khalique committed
149
                 weights.lens()[0],
Paul's avatar
Paul committed
150
151
152
153
154
155
156
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
        }
        else
        {
Paul's avatar
Paul committed
157
            MIGRAPHX_THROW("Invalid padding mode");
Paul's avatar
Paul committed
158
        }
Paul's avatar
Paul committed
159
160
161
    }
};

Scott Thornton's avatar
Scott Thornton committed
162
163
struct im2col
{
Scott Thornton's avatar
Scott Thornton committed
164
165
166
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Khalique's avatar
Khalique committed
167

Paul's avatar
Paul committed
168
169
170
171
172
    padding_mode_t padding_mode = default_;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
173
174
175
176
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
                    f(self.padding_mode, "padding_mode"));
Paul's avatar
Paul committed
177
    }
Scott Thornton's avatar
Scott Thornton committed
178
179
180
181
182

    std::string name() const { return "im2col"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
183
184
185
        auto input          = inputs[0];
        auto weights        = inputs[1];
        auto batch_size     = input.lens()[0];
Scott Thornton's avatar
Scott Thornton committed
186
        auto input_channels = weights.lens()[1];
Scott Thornton's avatar
Scott Thornton committed
187
188
        auto kernel_height  = weights.lens()[2];
        auto kernel_width   = weights.lens()[3];
Scott Thornton's avatar
Scott Thornton committed
189
        check_shapes{inputs, *this}.has(2);
Scott Thornton's avatar
Scott Thornton committed
190
        if(batch_size != 1)
Paul's avatar
Paul committed
191
            MIGRAPHX_THROW("im2col only support batch_size 1");
Scott Thornton's avatar
Scott Thornton committed
192
        auto output_height = std::size_t(std::max<std::ptrdiff_t>(
Scott Thornton's avatar
Scott Thornton committed
193
194
195
            1,
            (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) /
                    stride[0] +
Scott Thornton's avatar
Scott Thornton committed
196
                1));
Scott Thornton's avatar
Scott Thornton committed
197
198
199
200
        auto output_width  = std::size_t(std::max<std::ptrdiff_t>(
            1,
            (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) /
                    stride[1] +
Scott Thornton's avatar
Scott Thornton committed
201
                1));
Scott Thornton's avatar
Scott Thornton committed
202
203
        auto channels_col  = kernel_height * kernel_width * input_channels;
        return {input.type(), {output_height * output_width, channels_col}};
Scott Thornton's avatar
Scott Thornton committed
204
205
206
    }
};

Paul's avatar
Paul committed
207
struct pooling
Paul's avatar
Paul committed
208
{
Paul's avatar
Paul committed
209
    std::string mode                   = "average";
Paul's avatar
Paul committed
210
211
212
    std::array<std::size_t, 2> padding = {{0, 0}};
    std::array<std::size_t, 2> stride  = {{1, 1}};
    std::array<std::size_t, 2> lengths = {{1, 1}};
Khalique's avatar
Khalique committed
213
    padding_mode_t padding_mode        = default_;
Paul's avatar
Paul committed
214
215
216
217

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
218
219
        return pack(f(self.mode, "mode"),
                    f(self.padding, "padding"),
220
                    f(self.padding, "padding_mode"),
Paul's avatar
Paul committed
221
222
                    f(self.stride, "stride"),
                    f(self.lengths, "lengths"));
Paul's avatar
Paul committed
223
224
    }

Paul's avatar
Paul committed
225
    std::string name() const { return "pooling"; }
Scott Thornton's avatar
Scott Thornton committed
226

Paul's avatar
Paul committed
227
228
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
229
        check_shapes{inputs, *this}.has(1).only_dims(4);
Paul's avatar
Paul committed
230

Paul's avatar
Paul committed
231
        const shape& input = inputs.at(0);
Paul's avatar
Paul committed
232
        auto t             = input.type();
Paul's avatar
Paul committed
233

Paul's avatar
Paul committed
234
235
        assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
        assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
Paul's avatar
Paul committed
236

237
238
        if(padding_mode == default_)
        {
Khalique's avatar
Khalique committed
239
240
            return {
                t,
Scott Thornton's avatar
Scott Thornton committed
241
242
243
244
245
                {
                    input.lens()[0],
                    input.lens()[1],
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
246
                        std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
Paul's avatar
Paul committed
247
                                                  static_cast<float>(stride[0]))) +
Scott Thornton's avatar
Scott Thornton committed
248
249
250
                            1)),
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
251
                        std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
Paul's avatar
Paul committed
252
                                                  static_cast<float>(stride[1]))) +
Scott Thornton's avatar
Scott Thornton committed
253
254
                            1)),
                }};
255
256
257
258
259
260
261
262
263
264
265
266
267
268
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
                     input.lens()[1],
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {t,
Khalique's avatar
Khalique committed
269
270
271
272
273
274
275
276
277
278
279
280
281
282
                    {
                        input.lens()[0],
                        input.lens()[1],
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) /
                                                      static_cast<float>(stride[0]))) +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) /
                                                      static_cast<float>(stride[1]))) +
                                1)),
                    }};
283
284
285
286
287
        }
        else
        {
            MIGRAPHX_THROW("Invalid padding mode");
        }
Paul's avatar
Paul committed
288
289
290
    }
};

Khalique's avatar
Khalique committed
291
292
293
294
295
296
297
298
299
/// Leaky ReLU operator description.
struct leaky_relu
{
    std::string name() const { return "leaky_relu"; }

    // Initialized so that a default-constructed operator never carries an
    // indeterminate value. 0.01 follows the ONNX LeakyRelu default.
    // NOTE(review): confirm this default against the importing front-ends.
    float alpha = 0.01f;

    /// Requires exactly one input; the output shape equals the input shape.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }

    /// Exposes the operator's fields, with their names, to the visitor `f`.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.alpha, "alpha"));
    }
};

/// ELU operator description.
struct elu
{
    std::string name() const { return "elu"; }

    // Initialized so that a default-constructed operator never carries an
    // indeterminate value. 1.0 follows the ONNX Elu default.
    // NOTE(review): confirm this default against the importing front-ends.
    float alpha = 1.0f;

    /// Requires exactly one input; the output shape equals the input shape.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }

    /// Exposes the operator's fields, with their names, to the visitor `f`.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.alpha, "alpha"));
    }
};

325
326
327
struct transpose
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
328
329
330
331

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
332
        return pack(f(self.dims, "dims"));
Paul's avatar
Paul committed
333
334
    }

335
336
337
    std::string name() const { return "transpose"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
338
        check_shapes{inputs, *this}.has(1);
339
        auto input         = inputs.at(0);
340
        auto input_lens    = input.lens();
341
342
        auto input_strides = input.strides();
        auto t             = input.type();
Paul's avatar
Paul committed
343
344
        if(dims.size() != input_lens.size())
        {
Paul's avatar
Paul committed
345
            MIGRAPHX_THROW("Permutation has wrong number of axes");
346
347
348
        }
        std::vector<int64_t> axes(dims.size());
        std::iota(axes.begin(), axes.end(), 0);
Paul's avatar
Paul committed
349
350
        if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
        {
Paul's avatar
Paul committed
351
            MIGRAPHX_THROW("Invalid permutation");
352
        }
353
354
        std::vector<size_t> output_lens(input_lens.size());
        std::vector<size_t> output_strides(input_lens.size());
Paul's avatar
Paul committed
355
        for(std::size_t i = 0; i < output_lens.size(); i++)
Paul's avatar
Paul committed
356
357
        {
            output_lens[i]    = input_lens[dims[i]];
358
359
            output_strides[i] = input_strides[dims[i]];
        }
360
        return {t, output_lens, output_strides};
361
    }
Paul's avatar
Paul committed
362
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
363
    {
Paul's avatar
Paul committed
364
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
365
    }
Paul's avatar
Paul committed
366
    int output_alias(const std::vector<shape>&) const { return 0; }
367
368
};

wsttiger's avatar
fixes  
wsttiger committed
369
370
371
372
373
374
/// The contiguous operator takes a non-standard input tensor and returns
/// the same tensor in standard form. For example, if an input tensor A with lens = (4,5)
/// is first transposed, i.e. lens = (5,4), the tensor's data layout remains the same
/// during the transpose operation; only its shape lengths and strides are changed.
/// This leaves the tensor in a non-standard form. The contiguous operator copies the
/// underlying data so that the resulting tensor is returned in standard form.
Paul's avatar
Paul committed
375
struct contiguous
376
377
378
379
{
    std::string name() const { return "contiguous"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
380
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
381
382
        auto lens = inputs.at(0).lens();
        auto t    = inputs.at(0).type();
383
384
        return {t, lens};
    }
Paul's avatar
Paul committed
385
386
387
388
389
390
391
392
393
394
395
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        assert(output_shape.standard());
        argument result{output_shape};
        visit_all(result, args[0])([&](auto output, auto input) {
            shape_for_each(output.get_shape(), [&](const auto& idx) {
                output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
            });
        });
        return result;
    }
396
397
};

398
399
400
401
struct concat
{
    std::size_t axis = 0;
    std::string name() const { return "concat"; }
402
    std::vector<std::size_t> compute_offsets(const shape& output_shape,
Paul's avatar
Paul committed
403
                                             const std::vector<argument>& args) const
404
405
406
407
408
409
410
411
412
413
414
    {
        std::vector<std::size_t> offsets;
        std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
        offset[axis] = 0;
        for(const auto& arg : args)
        {
            offsets.push_back(output_shape.index(offset));
            offset[axis] += arg.get_shape().lens()[axis];
        }
        return offsets;
    }
415
416
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
417
        if(inputs.empty())
418
        {
Paul's avatar
Paul committed
419
            MIGRAPHX_THROW("Number of input tensors should exceed 0");
420
421
422
        }

        const auto& first_shape_lens = inputs.front().lens();
Scott Thornton's avatar
Scott Thornton committed
423
424
425
426
427
428
429
430
431
        const auto& type             = inputs.front().type();
        for(std::size_t l = 0; l < first_shape_lens.size(); l++)
        {
            if(l != axis)
            {
                if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
                       return s.lens()[l] == first_shape_lens[l];
                   }))
                {
Paul's avatar
Paul committed
432
                    MIGRAPHX_THROW("Non-axis dimensions should match");
433
434
435
436
                }
            }
        }
        std::size_t new_dim_axis = 0;
Scott Thornton's avatar
Scott Thornton committed
437
        for(const auto& input : inputs)
438
439
440
441
442
443
444
445
446
        {
            const auto& lens = input.lens();
            new_dim_axis += lens[axis];
        }
        std::vector<std::size_t> new_lens;
        std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
        new_lens[axis] = new_dim_axis;
        return {type, new_lens};
    }
Paul's avatar
Paul committed
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
        for(std::size_t l = 0; l < args.size(); l++)
        {
            auto argl             = args[l];
            std::size_t nelements = argl.get_shape().elements();
            visit_all(result, argl)([&](auto output, auto input) {
                auto slice_shape =
                    shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
                auto slice = make_view(slice_shape, output.data() + coffsets[l]);
                // cppcheck-suppress useStlAlgorithm
                for(std::size_t i = 0; i < nelements; i++)
                {
                    slice[i] = input[i];
                }
            });
        }
        return result;
    }
468
469
};

470
471
472
473
474
struct slice
{
    std::vector<int64_t> axes;
    std::vector<int64_t> starts;
    std::vector<int64_t> ends;
Paul's avatar
Paul committed
475
476
477
478

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
479
        return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
Paul's avatar
Paul committed
480
481
    }

482
    std::string name() const { return "slice"; }
Scott Thornton's avatar
Scott Thornton committed
483
484

    auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
485
    {
Scott Thornton's avatar
Scott Thornton committed
486
        int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
Scott Thornton's avatar
Scott Thornton committed
487
488
        if(r < 0)
            r += lens[axis];
Scott Thornton's avatar
Scott Thornton committed
489
        return std::size_t(r);
Scott Thornton's avatar
Scott Thornton committed
490
491
492
493
494
495
496
    }

    auto compute_offset(const shape& s) const
    {
        const std::vector<std::size_t>& lens    = s.lens();
        const std::vector<std::size_t>& strides = s.strides();
        auto offset                             = 0;
Scott Thornton's avatar
Scott Thornton committed
497
        if(!axes.empty())
Scott Thornton's avatar
Scott Thornton committed
498
        {
Scott Thornton's avatar
Scott Thornton committed
499
500
501
502
503
            for(std::size_t i = 0; i < axes.size(); i++)
            {
                auto axis = axes[i];
                offset += fix_index(lens, axis, starts[i]) * strides[axis];
            }
504
        }
Scott Thornton's avatar
Scott Thornton committed
505
506
        else
        {
Scott Thornton's avatar
Scott Thornton committed
507
508
509
510
            for(std::size_t axis = 0; axis < lens.size(); axis++)
            {
                offset += fix_index(lens, axis, starts[axis]) * strides[axis];
            }
511
        }
Scott Thornton's avatar
Scott Thornton committed
512
513
514
515
516
        return offset;
    }

    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
517
518
519
520
        auto input_shape        = inputs[0];
        auto t                  = input_shape.type();
        const auto& old_lens    = input_shape.lens();
        const auto& old_strides = input_shape.strides();
Scott Thornton's avatar
Scott Thornton committed
521
        if(starts.size() != axes.size() || axes.size() != ends.size())
Scott Thornton's avatar
Scott Thornton committed
522
        {
Paul's avatar
Paul committed
523
            MIGRAPHX_THROW("inconsistent sizes");
524
        }
Scott Thornton's avatar
Scott Thornton committed
525
526
        std::vector<std::size_t> new_lens = old_lens;
        for(std::size_t i = 0; i < axes.size(); i++)
Scott Thornton's avatar
Scott Thornton committed
527
        {
Scott Thornton's avatar
Scott Thornton committed
528
529
530
            auto axis = axes[i];
            new_lens[axis] =
                fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
531
532
533
        }
        return shape{t, new_lens, old_strides};
    }
Paul's avatar
Paul committed
534
    argument compute(shape output_shape, std::vector<argument> args) const
535
    {
Scott Thornton's avatar
Scott Thornton committed
536
537
538
        auto input  = args[0];
        auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
        return {std::move(output_shape), [=] { return input.data() + offset; }};
539
    }
Paul's avatar
Paul committed
540
    int output_alias(const std::vector<shape>&) const { return 0; }
541
542
543
544
545
};

struct squeeze
{
    std::vector<int64_t> axes;
Paul's avatar
Paul committed
546
547
548
549

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
550
        return pack(f(self.axes, "axes"));
Paul's avatar
Paul committed
551
552
    }

553
554
555
556
    std::string name() const { return "squeeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        auto input_shape = inputs[0];
Scott Thornton's avatar
Scott Thornton committed
557
558
        auto type        = input_shape.type();
        auto old_lens    = input_shape.lens();
wsttiger's avatar
wsttiger committed
559
560
        if(std::any_of(
               axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; }))
Scott Thornton's avatar
Scott Thornton committed
561
        {
Paul's avatar
Paul committed
562
            MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
563
564
        }
        std::vector<std::size_t> new_lens;
Scott Thornton's avatar
Scott Thornton committed
565
        if(axes.empty())
Scott Thornton's avatar
Scott Thornton committed
566
        {
wsttiger's avatar
wsttiger committed
567
568
569
570
            std::copy_if(old_lens.begin(),
                         old_lens.end(),
                         std::back_inserter(new_lens),
                         [](auto len) { return len != 1; });
571
        }
Scott Thornton's avatar
Scott Thornton committed
572
573
574
575
576
577
        else
        {
            for(std::size_t i = 0; i < old_lens.size(); i++)
            {
                if(std::find(axes.begin(), axes.end(), i) == axes.end())
                {
578
579
580
581
582
583
                    new_lens.push_back(old_lens[i]);
                }
            }
        }
        return shape{type, new_lens};
    }
Paul's avatar
Paul committed
584
    argument compute(shape output_shape, std::vector<argument> args) const
585
586
    {
        return {std::move(output_shape), std::move(args.front().data)};
Scott Thornton's avatar
Scott Thornton committed
587
    }
Paul's avatar
Paul committed
588
    int output_alias(const std::vector<shape>&) const { return 0; }
589
590
591
592
593
};

struct unsqueeze
{
    std::vector<int64_t> axes;
Paul's avatar
Paul committed
594
595
596
597

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
598
        return pack(f(self.axes, "axes"));
Paul's avatar
Paul committed
599
600
    }

601
602
603
    std::string name() const { return "unsqueeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
604
605
606
        auto input_shape     = inputs[0];
        auto type            = input_shape.type();
        auto old_lens        = input_shape.lens();
607
608
609
        std::size_t new_size = old_lens.size() + axes.size();
        std::vector<std::size_t> new_lens(new_size);
        std::size_t p = 0;
Scott Thornton's avatar
Scott Thornton committed
610
611
612
613
        for(std::size_t i = 0; i < new_size; i++)
        {
            if(std::find(axes.begin(), axes.end(), i) != axes.end())
            {
614
                new_lens[i] = 1;
Scott Thornton's avatar
Scott Thornton committed
615
616
617
            }
            else
            {
618
619
620
621
622
                new_lens[i] = old_lens[p++];
            }
        }
        return shape{type, new_lens};
    }
Paul's avatar
Paul committed
623
    argument compute(shape output_shape, std::vector<argument> args) const
624
625
626
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
Paul's avatar
Paul committed
627
    int output_alias(const std::vector<shape>&) const { return 0; }
628
629
};

Paul's avatar
Paul committed
630
631
632
struct reshape
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
633
634
635
636

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
637
        return pack(f(self.dims, "dims"));
Paul's avatar
Paul committed
638
639
    }

Paul's avatar
Paul committed
640
    std::string name() const { return "reshape"; }
Paul's avatar
Paul committed
641
642
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
643
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
644
645
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(dims.begin(), dims.end());
646
647
        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
        if(n_neg_dims > 1)
Paul's avatar
Paul committed
648
            MIGRAPHX_THROW("Dimensions for reshape can only have one -1 dim");
Paul's avatar
Paul committed
649
        for(std::size_t i = 0; i < dims.size(); i++)
Paul's avatar
Paul committed
650
651
652
        {
            if(dims[i] == 0)
                rdims[i] = idims[i];
Shucai Xiao's avatar
Shucai Xiao committed
653
654
655

            // since rdims using size_t type, -1 is the max value
            // is size_t that cause later compuation incorrect
Shucai Xiao's avatar
Shucai Xiao committed
656
            if(dims[i] == -1)
Shucai Xiao's avatar
Shucai Xiao committed
657
                rdims[i] = 1;
Paul's avatar
Paul committed
658
        }
659
660
661
        if(n_neg_dims > 0)
        {
            size_t missing_dim =
Shucai Xiao's avatar
Shucai Xiao committed
662
                inputs.front().elements() /
663
664
665
666
667
668
669
                std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
            for(std::size_t i = 0; i < rdims.size(); i++)
            {
                if(dims[i] == -1)
                    rdims[i] = missing_dim;
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
670

Scott Thornton's avatar
Scott Thornton committed
671
        shape s{inputs.front().type(), rdims};
Paul's avatar
Paul committed
672
        if(s.elements() != inputs.front().elements())
Paul's avatar
Paul committed
673
            MIGRAPHX_THROW("Wrong number of elements for reshape");
Scott Thornton's avatar
Scott Thornton committed
674
        return s;
Paul's avatar
Paul committed
675
    }
Paul's avatar
Paul committed
676
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
677
    {
Paul's avatar
Paul committed
678
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
679
    }
Paul's avatar
Paul committed
680
    int output_alias(const std::vector<shape>&) const { return 0; }
Paul's avatar
Paul committed
681
682
};

Khalique's avatar
Khalique committed
683
684
685
/// Pad operator: enlarges each dimension of its single input by the amounts in `pads`.
struct pad
{
    /// Two entries per input dimension: for dimension i both pads[i] and
    /// pads[i + rank] are added to its length (presumably the leading and
    /// trailing amounts, as in ONNX Pad -- confirm against the parser).
    std::vector<int64_t> pads;
    /// Fill value attribute (not used by the shape computation here).
    float value = 0.0f;
    enum pad_op_mode_t
    {
        constant_pad,
        reflect_pad,
        edge_pad
    };
    pad_op_mode_t mode = constant_pad;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.mode, "mode"), f(self.pads, "pads"), f(self.value, "value"));
    }

    std::string name() const { return "pad"; }

    /// Returns the input shape with every dimension grown by its two pad amounts.
    /// Throws when `pads` does not hold exactly two entries per input dimension.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(idims.begin(), idims.end());
        const std::size_t num_dims = rdims.size();

        // Without this check the loop below would index past the end of pads.
        if(pads.size() != 2 * num_dims)
            MIGRAPHX_THROW("PAD: pads must contain two entries per input dimension");

        for(std::size_t i = 0; i < num_dims; i++)
        {
            rdims[i] += pads[i] + pads[i + num_dims];
        }

        return {inputs.front().type(), rdims};
    }

    /// True when the first half of `pads` equals the second half element by element.
    bool symmetric() const
    {
        const std::size_t num_dims = pads.size() / 2;
        const auto middle          = pads.begin() + num_dims;
        return std::equal(pads.begin(), middle, middle, pads.end());
    }
};

Paul's avatar
Paul committed
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
/// Reinterprets its single (standard-layout) input under the stored shape `s`.
struct as_shape
{
    /// Shape reported for the output.
    shape s;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.s, "shape"));
    }

    std::string name() const { return "as_shape"; }

    /// Returns `s` after checking that the input holds the same number of elements.
    /// The check throws instead of asserting so that it is also enforced when
    /// assertions are compiled out.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        if(inputs.front().elements() != s.elements())
            MIGRAPHX_THROW("as_shape: input element count does not match the target shape");
        return s;
    }

    /// Moves the first argument's `data` member into a result carrying `output_shape`.
    argument compute(shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
    int output_alias(const std::vector<shape>&) const { return 0; }
};

748
749
/// Gather operator: selects entries of input 0 along `axis` using the indices in input 1.
struct gather
{
    /// Axis to gather along; a negative value counts dimensions from the back.
    int axis = 0;
    std::string name() const { return "gather"; }

    /// Output dimensions are those of input 0 with the `axis` dimension removed and,
    /// unless input 1 is a scalar, the index dimensions inserted at that position.
    /// Throws when `axis` lies outside [-rank, rank).
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2);
        auto lens       = inputs[0].lens();
        const int n_dim = static_cast<int>(lens.size());
        if(axis >= n_dim || axis < -n_dim)
        {
            MIGRAPHX_THROW("Gather: axis is out of range.");
        }

        // negative axis means counting dimensions from back
        const int axis_index = (axis < 0) ? (n_dim + axis) : axis;

        auto type = inputs[0].type();
        lens.erase(lens.begin() + axis_index);
        if(!inputs[1].scalar())
        {
            auto ind_lens = inputs[1].lens();
            lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
        }

        // for scalar output
        if(lens.empty())
        {
            return {type};
        }

        return {type, lens};
    }

    /// Fills a freshly created result of shape `output_shape` from args[0] (data)
    /// and args[1] (indices).
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        // negative axis means counting dimensions from back; normalise in signed
        // arithmetic rather than relying on unsigned wrap-around
        const int n_dim      = static_cast<int>(args[0].get_shape().lens().size());
        const int axis_index = (axis < 0) ? (n_dim + axis) : axis;

        visit_all(result, args[0])([&](auto output, auto data) {
            args[1].visit([&](auto indices) {
                if(output_shape.scalar())
                {
                    output[0] = data[indices.front()];
                }
                else
                {
                    // Iterate over a shape of the data's rank whose gather axis has
                    // one entry per index, mapping every position back into the data.
                    auto out_lens        = data.get_shape().lens();
                    out_lens[axis_index] = indices.get_shape().elements();
                    migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
                    shape_for_each(out_comp_shape, [&](const auto& out_idx) {
                        auto data_idx        = out_idx;
                        data_idx[axis_index] = indices[data_idx[axis_index]];
                        output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
                            data(data_idx.begin(), data_idx.end());
                    });
                }
            });
        });

        return result;
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
816
struct dot
817
{
Paul's avatar
Paul committed
818
    float alpha = 1.0;
Paul's avatar
Paul committed
819
    float beta  = 0.0;
Paul's avatar
Paul committed
820
821
822
823

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
824
        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
Paul's avatar
Paul committed
825
826
    }

Shucai Xiao's avatar
Shucai Xiao committed
827
    std::string name() const { return "dot"; }
828
829
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
830
        check_shapes{inputs, *this}.has(2).same_type();
831
832
        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);
Scott Thornton's avatar
Scott Thornton committed
833
        auto t         = a.type();
834

835
836
837
838
        // according to the specification of the numpy.matmul()
        // inputs with the shape dims more than 2 are acceptable
        // as long as dim values are the same in the two inputs
        if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
839
        {
840
            MIGRAPHX_THROW("DOT: dim values mismatch");
841
842
        }

843
844
845
        std::size_t dim_0 = a.lens().size() - 2;
        std::size_t dim_1 = a.lens().size() - 1;
        if(a.lens()[dim_1] != b.lens()[dim_0])
Paul's avatar
Paul committed
846
847
            MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
                           "} x {" + to_string_range(b.lens()) + "}");
Shucai Xiao's avatar
Shucai Xiao committed
848
        auto out_lens   = a.lens();
849
        out_lens[dim_1] = b.lens()[dim_1];
850
        return {t, out_lens};
851
852
853
    }
};

854
struct unary
Scott Thornton's avatar
Scott Thornton committed
855
{
856
857
    shape compute_shape(std::vector<shape> inputs) const
    {
858
859
        check_shapes{inputs}.has(1);
        return inputs.at(0);
860
    }
Scott Thornton's avatar
Scott Thornton committed
861
862
};

863
struct identity
864
{
865
    std::string name() const { return "identity"; }
Scott Thornton's avatar
Scott Thornton committed
866
    shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
Paul's avatar
Paul committed
867
    argument compute(shape output_shape, std::vector<argument> args) const
868
869
870
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
Paul's avatar
Paul committed
871
    int output_alias(const std::vector<shape>&) const { return 0; }
872
873
874
};

struct abs : unary
Scott Thornton's avatar
Scott Thornton committed
875
{
876
    std::string name() const { return "abs"; }
Scott Thornton's avatar
Scott Thornton committed
877
878
};

879
struct exp : unary
Scott Thornton's avatar
Scott Thornton committed
880
{
881
    std::string name() const { return "exp"; }
Scott Thornton's avatar
Scott Thornton committed
882
883
};

Shucai Xiao's avatar
Shucai Xiao committed
884
885
886
887
888
/// Operator tagged "log"; shape inference comes from the `unary` base.
struct log : unary
{
    /// Name under which this operator is identified.
    std::string name() const
    {
        return "log";
    }
};

889
struct sin : unary
Scott Thornton's avatar
Scott Thornton committed
890
{
891
    std::string name() const { return "sin"; }
Scott Thornton's avatar
Scott Thornton committed
892
893
};

894
struct cos : unary
Scott Thornton's avatar
Scott Thornton committed
895
{
896
    std::string name() const { return "cos"; }
Scott Thornton's avatar
Scott Thornton committed
897
898
};

899
struct tan : unary
Scott Thornton's avatar
Scott Thornton committed
900
{
901
    std::string name() const { return "tan"; }
Scott Thornton's avatar
Scott Thornton committed
902
903
};

904
struct asin : unary
Scott Thornton's avatar
Scott Thornton committed
905
{
906
    std::string name() const { return "asin"; }
Scott Thornton's avatar
Scott Thornton committed
907
908
};

909
struct acos : unary
Scott Thornton's avatar
Scott Thornton committed
910
{
911
    std::string name() const { return "acos"; }
Scott Thornton's avatar
Scott Thornton committed
912
913
};

914
struct atan : unary
Scott Thornton's avatar
Scott Thornton committed
915
{
916
    std::string name() const { return "atan"; }
Scott Thornton's avatar
Scott Thornton committed
917
918
};

919
920
921
922
923
924
925
926
927
928
/// Operator tagged "sinh"; shape inference comes from the `unary` base.
struct sinh : unary
{
    /// Name under which this operator is identified.
    std::string name() const
    {
        return "sinh";
    }
};

/// Operator tagged "cosh"; shape inference comes from the `unary` base.
struct cosh : unary
{
    /// Name under which this operator is identified.
    std::string name() const
    {
        return "cosh";
    }
};

929
struct tanh : unary
Scott Thornton's avatar
Scott Thornton committed
930
{
931
    std::string name() const { return "tanh"; }
Scott Thornton's avatar
Scott Thornton committed
932
933
};

934
struct sigmoid : unary
Scott Thornton's avatar
Scott Thornton committed
935
{
936
    std::string name() const { return "sigmoid"; }
Scott Thornton's avatar
Scott Thornton committed
937
938
};

939
struct neg : unary
Scott Thornton's avatar
Scott Thornton committed
940
{
941
    std::string name() const { return "neg"; }
Scott Thornton's avatar
Scott Thornton committed
942
943
};

Khalique's avatar
Khalique committed
944
945
946
947
948
/// Operator tagged "relu"; shape inference comes from the `unary` base.
struct relu : unary
{
    /// Name under which this operator is identified.
    std::string name() const
    {
        return "relu";
    }
};

Paul's avatar
Paul committed
949
950
951
952
953
954
955
956
957
958
/// Softmax operator: accepts exactly one 4-dimensional input and keeps its shape.
struct softmax
{
    std::string name() const { return "softmax"; }

    /// Validates the input count and rank, then returns the input shape as is.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes checker{inputs};
        checker.has(1).only_dims(4);
        return inputs.at(0);
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
959
960
961
962
963
964
965
/// LogSoftmax operator: output shape equals the single input's shape.
struct logsoftmax
{
    /// Axis attribute; must index an existing dimension of the input.
    int axis = 1;
    std::string name() const { return "logsoftmax"; }

    /// Returns the input shape after validating `axis`.
    /// Throws when `axis` is negative or not smaller than the input rank
    /// (`axis == rank` does not name a dimension and is rejected too).
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
        const std::size_t rank = inputs[0].lens().size();
        if(axis < 0 || static_cast<std::size_t>(axis) >= rank)
        {
            MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
                           " is out of range");
        }
        return inputs.at(0);
    }
};

975
struct flatten
Scott Thornton's avatar
Scott Thornton committed
976
{
Paul's avatar
Paul committed
977
    uint64_t axis = 0;
Paul's avatar
Paul committed
978
979
980
981

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
982
        return pack(f(self.axis, "axis"));
Paul's avatar
Paul committed
983
984
    }

Scott Thornton's avatar
Scott Thornton committed
985
    std::string name() const { return "flatten"; }
Paul's avatar
Paul committed
986
987
988
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
Paul's avatar
Paul committed
989
990
        auto&& lens = inputs.front().lens();

Paul's avatar
Paul committed
991
        if(axis > lens.size())
Paul's avatar
Paul committed
992
        {
Paul's avatar
Paul committed
993
            MIGRAPHX_THROW("axis for flatten must be less than tensor rank");
Paul's avatar
Paul committed
994
        }
Paul's avatar
Paul committed
995
996
997
998
        auto x =
            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
        auto y =
            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
999
        return {inputs.at(0).type(), {x, y}};
Paul's avatar
Paul committed
1000
    }
Paul's avatar
Paul committed
1001
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
1002
    {
Paul's avatar
Paul committed
1003
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
1004
    }
Paul's avatar
Paul committed
1005
    int output_alias(const std::vector<shape>&) const { return 0; }
Scott Thornton's avatar
Scott Thornton committed
1006
};
1007

wsttiger's avatar
fixes  
wsttiger committed
1008
1009
1010
1011
1012
1013
1014
1015
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indices are
/// computed from multi-indices by computing the inner product of the multi-index with the strides.
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
/// obvious from there that we can negate the effects of a given axis by setting the stride of that
/// axis to zero.
1016
1017
1018
struct broadcast
{
    uint64_t axis = 0;
Paul's avatar
Paul committed
1019
1020
1021
1022

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1023
        return pack(f(self.axis, "axis"));
Paul's avatar
Paul committed
1024
1025
    }

Scott Thornton's avatar
Scott Thornton committed
1026
    shape broadcast_shape;
1027
1028
1029
    std::string name() const { return "broadcast"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
1030
1031
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);
Paul's avatar
Paul committed
1032

Scott Thornton's avatar
Scott Thornton committed
1033
        std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0);
1034

Scott Thornton's avatar
Scott Thornton committed
1035
1036
1037
        if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) {
               return x == 1;
           }))
1038
        {
Scott Thornton's avatar
Scott Thornton committed
1039
            if(axis != 0)
Paul's avatar
Paul committed
1040
                MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0");
Scott Thornton's avatar
Scott Thornton committed
1041
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
1042
1043
1044
        }
        else
        {
Scott Thornton's avatar
Scott Thornton committed
1045
            assert(broadcast_shape.lens().size() - axis >= input.lens().size());
Scott Thornton's avatar
Scott Thornton committed
1046
1047
            if(!std::equal(
                   input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
Paul's avatar
Paul committed
1048
                MIGRAPHX_THROW("when broadcasting success sizes must match");
Paul's avatar
Paul committed
1049
            std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
Scott Thornton's avatar
Scott Thornton committed
1050
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
1051
1052
        }
    }
Paul's avatar
Paul committed
1053
    argument compute(shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
1054
    {
Scott Thornton's avatar
Scott Thornton committed
1055
        return {std::move(output_shape), std::move(args.at(0).data)};
Scott Thornton's avatar
Scott Thornton committed
1056
    }
Paul's avatar
Paul committed
1057
    int output_alias(const std::vector<shape>&) const { return 0; }
1058
1059
};

Scott Thornton's avatar
Scott Thornton committed
1060
1061
1062
struct multibroadcast
{
    std::vector<std::size_t> output_lens;
1063
1064
1065
1066
1067
1068
1069

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.output_lens, "output_lens"));
    }

Scott Thornton's avatar
Scott Thornton committed
1070
    std::string name() const { return "multibroadcast"; }
1071

Scott Thornton's avatar
Scott Thornton committed
1072
1073
1074
1075
1076
1077
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);

wsttiger's avatar
wsttiger committed
1078
        if(input.lens().empty())
Paul's avatar
Paul committed
1079
            MIGRAPHX_THROW("inputs dimensions should be > 0");
Scott Thornton's avatar
Scott Thornton committed
1080

Scott Thornton's avatar
Scott Thornton committed
1081
        if(input.lens().size() > output_lens.size())
Paul's avatar
Paul committed
1082
            MIGRAPHX_THROW("inputs dimensions should <= output size");
Scott Thornton's avatar
Scott Thornton committed
1083
1084

        std::vector<size_t> bcast_strides(output_lens.size(), 0);
Scott Thornton's avatar
Scott Thornton committed
1085
1086
        auto offset = output_lens.size() - input.lens().size();
        for(int i = input.lens().size() - 1; i >= 0; i--)
Scott Thornton's avatar
Scott Thornton committed
1087
        {
Scott Thornton's avatar
Scott Thornton committed
1088
            if(output_lens[i + offset] == input.lens()[i])
Scott Thornton's avatar
Scott Thornton committed
1089
            {
Scott Thornton's avatar
Scott Thornton committed
1090
                bcast_strides[i + offset] = input.strides()[i];
Scott Thornton's avatar
Scott Thornton committed
1091
1092
1093
1094
            }
        }
        return {t, output_lens, bcast_strides};
    }
Paul's avatar
Paul committed
1095
    argument compute(shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
1096
1097
1098
1099
1100
1101
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
    int output_alias(const std::vector<shape>&) const { return 0; }
};

Khalique's avatar
Khalique committed
1102
1103
1104
1105
1106
1107
1108
1109
1110
struct scalar
{
    shape scalar_bcast;

    std::string name() const { return "scalar"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
Paul's avatar
Paul committed
1111
        auto t = inputs.at(0).type();
Khalique's avatar
Khalique committed
1112
1113
1114
1115
        std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0);
        return {t, scalar_bcast.lens(), strides};
    }

Paul's avatar
Paul committed
1116
    argument compute(shape output_shape, std::vector<argument> args) const
Khalique's avatar
Khalique committed
1117
1118
1119
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
Paul's avatar
Paul committed
1120
    int output_alias(const std::vector<shape>&) const { return 0; }
Khalique's avatar
Khalique committed
1121
1122
};

1123
struct binary
Scott Thornton's avatar
Scott Thornton committed
1124
{
1125
1126
    shape compute_shape(std::vector<shape> inputs) const
    {
1127
        check_shapes{inputs}.has(2).same_type().same_dims();
Scott Thornton's avatar
Scott Thornton committed
1128
        auto t    = inputs.at(0).type();
1129
1130
        auto lens = inputs.at(0).lens();
        return {t, lens};
1131
    }
Scott Thornton's avatar
Scott Thornton committed
1132
1133
};

1134
1135
1136
1137
1138
1139
/// Operator tagged "add"; shape inference comes from the `binary` base.
struct add : binary
{
    /// Name under which this operator is identified.
    std::string name() const
    {
        return "add";
    }
};

struct sub : binary
Scott Thornton's avatar
Scott Thornton committed
1140
1141
1142
1143
{
    std::string name() const { return "sub"; }
};

1144
struct mul : binary
Scott Thornton's avatar
Scott Thornton committed
1145
1146
1147
1148
{
    std::string name() const { return "mul"; }
};

1149
struct div : binary
Scott Thornton's avatar
Scott Thornton committed
1150
1151
1152
1153
{
    std::string name() const { return "div"; }
};

Khalique's avatar
Khalique committed
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
/// Operator tagged "max"; shape inference comes from the `binary` base.
struct max : binary
{
    /// Name under which this operator is identified.
    std::string name() const
    {
        return "max";
    }
};

/// Operator tagged "min"; shape inference comes from the `binary` base.
struct min : binary
{
    /// Name under which this operator is identified.
    std::string name() const
    {
        return "min";
    }
};

Paul's avatar
Paul committed
1164
1165
1166
1167
struct load
{
    shape s;
    std::size_t offset = 0;
Paul's avatar
Paul committed
1168
1169
1170
1171

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1172
        return pack(f(self.s, "shape"), f(self.offset, "offset"));
Paul's avatar
Paul committed
1173
1174
    }

Paul's avatar
Paul committed
1175
1176
1177
1178
1179
1180
    std::string name() const { return "load"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs}.has(1);
        return s;
    }
Paul's avatar
Paul committed
1181
    argument compute(const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
1182
    {
Paul's avatar
Paul committed
1183
1184
        if((offset + s.bytes()) > args[0].get_shape().bytes())
            MIGRAPHX_THROW("Load access is out of bounds");
Paul's avatar
Paul committed
1185
1186
        return {s, args[0].data() + offset};
    }
Paul's avatar
Paul committed
1187
    int output_alias(const std::vector<shape>&) const { return 0; }
Paul's avatar
Paul committed
1188
1189
1190
1191
1192
1193
1194
1195

    friend std::ostream& operator<<(std::ostream& os, const load& op)
    {
        os << op.name() << "[";
        os << "offset=" << op.offset << ",";
        os << "end=" << (op.offset + op.s.bytes()) << "]";
        return os;
    }
Paul's avatar
Paul committed
1196
1197
};

Paul's avatar
Paul committed
1198
struct outline
Scott Thornton's avatar
Scott Thornton committed
1199
{
Paul's avatar
Paul committed
1200
    shape s;
Paul's avatar
Paul committed
1201
1202
1203
1204

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1205
        return pack(f(self.s, "shape"));
Paul's avatar
Paul committed
1206
1207
    }

Paul's avatar
Paul committed
1208
    std::string name() const { return "outline"; }
Paul's avatar
Paul committed
1209
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
1210
    {
Paul's avatar
Paul committed
1211
        check_shapes{inputs, *this}.has(0);
Paul's avatar
Paul committed
1212
1213
        return s;
    }
Paul's avatar
Paul committed
1214
    argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
Scott Thornton's avatar
Scott Thornton committed
1215
1216
};

1217
1218
// Direction in which a recurrent operator processes its input; `bidirectional`
// makes the rnn/gru/lstm shape computations use two directions instead of one.
enum class rnn_direction
{
    forward,
    reverse,
    bidirectional,
};
Shucai Xiao's avatar
Shucai Xiao committed
1224

1225
1226
struct rnn
{
Shucai Xiao's avatar
Shucai Xiao committed
1227
    std::size_t hidden_size = 1;
1228
    std::vector<operation> actv_funcs{tanh{}, tanh{}};
1229
    rnn_direction direction = rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1230
    float clip              = 0.0f;
Shucai Xiao's avatar
Shucai Xiao committed
1231
1232
1233
1234

    std::string name() const { return "rnn"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Shucai Xiao's avatar
Shucai Xiao committed
1235
        auto in_dims     = inputs[0].lens();
Shucai Xiao's avatar
Shucai Xiao committed
1236
1237
        auto hidden_dims = inputs[2].lens();
        if(hidden_size != hidden_dims[2])
Shucai Xiao's avatar
Shucai Xiao committed
1238
1239
1240
1241
1242
        {
            MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
        }

        std::size_t num_directions = 1;
1243
        if(direction == rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1244
1245
1246
1247
        {
            num_directions = 2;
        }

Shucai Xiao's avatar
Shucai Xiao committed
1248
        if(num_directions != hidden_dims[0])
Shucai Xiao's avatar
Shucai Xiao committed
1249
        {
1250
            MIGRAPHX_THROW("RNN: num_direction mismatch in attribute and input");
Shucai Xiao's avatar
Shucai Xiao committed
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
        }

        std::vector<std::size_t> out_dims(in_dims);
        out_dims.insert(out_dims.begin() + 1, num_directions);
        out_dims.back() = hidden_size;

        return {inputs[0].type(), out_dims};
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
/// Operator tagged "rnn_last_output": its output shape is the single input's
/// shape without the leading dimension.
struct rnn_last_output
{
    std::string name() const { return "rnn_last_output"; }

    /// Requires exactly one input and drops its first dimension.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto remaining = inputs[0].lens();

        // the dimensions after the first one form the output shape
        remaining.erase(remaining.begin());
        return {inputs[0].type(), remaining};
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
1275
1276
1277
1278
struct gru
{
    std::size_t hidden_size = 1;
    std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
1279
    rnn_direction direction = rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1280
1281
    float clip              = 0.0f;
    int linear_before_reset = 0;
Shucai Xiao's avatar
Shucai Xiao committed
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293

    std::string name() const { return "gru"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        auto in_dims     = inputs[0].lens();
        auto hidden_dims = inputs[2].lens();
        if(hidden_size != hidden_dims[2])
        {
            MIGRAPHX_THROW("GRU: hidden size mismatch in attribute and input");
        }

        std::size_t num_directions = 1;
1294
        if(direction == rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
        {
            num_directions = 2;
        }

        if(num_directions != hidden_dims[0])
        {
            MIGRAPHX_THROW("GRU: num_direction does not match the direction attribute");
        }

        std::vector<std::size_t> out_dims(in_dims);
        out_dims.insert(out_dims.begin() + 1, num_directions);
        out_dims.back() = hidden_size;

        return {inputs[0].type(), out_dims};
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
1312
1313
1314
1315
struct lstm
{
    std::size_t hidden_size = 1;
    std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
Shucai Xiao's avatar
Shucai Xiao committed
1316
    rnn_direction direction = rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1317
1318
    float clip              = 0.0f;
    int input_forget        = 0;
Shucai Xiao's avatar
Shucai Xiao committed
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330

    std::string name() const { return "lstm"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        auto in_dims     = inputs[0].lens();
        auto hidden_dims = inputs[2].lens();
        if(hidden_size != hidden_dims[2])
        {
            MIGRAPHX_THROW("LSTM: hidden size mismatch in attribute and input");
        }

        std::size_t num_directions = 1;
Shucai Xiao's avatar
Shucai Xiao committed
1331
        if(direction == rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
        {
            num_directions = 2;
        }

        if(num_directions != hidden_dims[0])
        {
            MIGRAPHX_THROW("LSTM: num_direction does not match the direction attribute");
        }

        std::vector<std::size_t> out_dims(in_dims);
        out_dims.insert(out_dims.begin() + 1, num_directions);
        out_dims.back() = hidden_size;

        return {inputs[0].type(), out_dims};
    }
};

/// Operator tagged "lstm_last_cell_output": its output shape is the single
/// input's shape without the leading dimension.
struct lstm_last_cell_output
{
    std::string name() const { return "lstm_last_cell_output"; }

    /// Requires exactly one input and drops its first dimension.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto remaining = inputs[0].lens();

        // the dimensions after the first one form the output shape
        remaining.erase(remaining.begin());
        return {inputs[0].type(), remaining};
    }
};

1363
1364
1365
struct undefined
{
    std::string name() const { return "undefined"; }
Shucai Xiao's avatar
Shucai Xiao committed
1366
    shape compute_shape(const std::vector<shape>& inputs) const
1367
1368
1369
1370
1371
1372
1373
1374
    {
        check_shapes{inputs, *this}.has(0);
        return {};
    }

    argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
};

1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
/// Placeholder operator that carries the name of an operation it stands in for.
struct unknown
{
    /// Name appended after the "unknown:" prefix.
    std::string op;
    std::string name() const { return "unknown:" + op; }

    /// Returns the first input shape, or a default-constructed shape when there
    /// are no inputs.
    shape compute_shape(std::vector<shape> input) const
    {
        if(!input.empty())
            return input.front();
        return {};
    }

    /// Streams the operator's full name.
    friend std::ostream& operator<<(std::ostream& os, const unknown& x)
    {
        return os << x.name();
    }
};

1394
} // namespace op
Paul's avatar
Paul committed
1395
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
1396
} // namespace migraphx
Paul's avatar
Paul committed
1397
1398

#endif