operators.hpp 41.9 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_OPERATORS_HPP
#define MIGRAPHX_GUARD_OPERATORS_HPP
Paul's avatar
Paul committed
3

4
#include <array>
Paul's avatar
Paul committed
5
6
7
8
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
9
10
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
Paul's avatar
Paul committed
11
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
12
#include <cmath>
Paul's avatar
Paul committed
13
#include <utility>
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
namespace migraphx {
Paul's avatar
Paul committed
16
inline namespace MIGRAPHX_INLINE_NS {
17
namespace op {
Paul's avatar
Paul committed
18

19
20
21
22
23
24
25
/// Padding policy understood by the windowed operators (convolution, im2col, pooling).
///  - default_: spatial sizes are derived from the operator's explicit `padding` values.
///  - same:     each spatial output size is ceil(input / stride).
///  - valid:    no padding is applied.
enum padding_mode_t
{
    default_ = 0, // NOLINT
    same     = 1,
    valid    = 2
};

Paul's avatar
Paul committed
26
27
struct not_computable
{
Paul's avatar
Paul committed
28
    argument compute(const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
29
    {
Paul's avatar
Paul committed
30
        MIGRAPHX_THROW("not computable");
Paul's avatar
Paul committed
31
    }
Paul's avatar
Paul committed
32
33
};

34
35
struct batch_norm_inference
{
36
37
    float epsilon  = 1.0e-6f;
    float momentum = 0.9f;
38
39
40

    std::string name() const { return "batch_norm_inference"; }

41
42
43
44
45
46
47
48
    enum bn_infer_mode_t
    {
        per_activation,
        spatial,
    };

    bn_infer_mode_t bn_mode = spatial;

Paul's avatar
Paul committed
49
50
51
52
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(
Paul's avatar
Paul committed
53
            f(self.epsilon, "epsilon"), f(self.momentum, "momentum"), f(self.bn_mode, "bn_mode"));
Paul's avatar
Paul committed
54
    }
55

56
57
58
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(5);
Paul's avatar
Paul committed
59
60
61
        check_shapes{inputs.data(), inputs.data() + 1, *this}.only_dims(4);
        check_shapes{inputs.data() + 1, inputs.data() + inputs.size(), *this}.same_shape().elements(
            inputs.front().lens()[1]);
62
63
64
65
        return inputs.front();
    }
};

Khalique's avatar
Khalique committed
66
struct lrn
Khalique's avatar
Khalique committed
67
68
{
    float alpha = 0.0001;
Khalique's avatar
Khalique committed
69
70
    float beta  = 0.75;
    float bias  = 1.0;
Khalique's avatar
Khalique committed
71
    int size    = 1;
Khalique's avatar
Khalique committed
72
    std::string name() const { return "lrn"; }
Khalique's avatar
Khalique committed
73

Khalique's avatar
Khalique committed
74
75
76
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Khalique's avatar
Khalique committed
77
78
79
80
        return pack(f(self.alpha, "alpha"),
                    f(self.beta, "beta"),
                    f(self.bias, "bias"),
                    f(self.size, "size"));
Khalique's avatar
Khalique committed
81
82
    }

Khalique's avatar
Khalique committed
83
84
85
86
87
88
89
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }
};

Paul's avatar
Paul committed
90
struct convolution
Paul's avatar
Paul committed
91
{
Paul's avatar
Paul committed
92
93
94
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Khalique's avatar
Khalique committed
95

Paul's avatar
Paul committed
96
    padding_mode_t padding_mode = default_;
Khalique's avatar
Khalique committed
97
    int group                   = 1;
Paul's avatar
Paul committed
98
99
100
101

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
102
103
104
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
Khalique's avatar
Khalique committed
105
106
                    f(self.padding_mode, "padding_mode"),
                    f(self.group, "group"));
Paul's avatar
Paul committed
107
108
    }

Paul's avatar
Paul committed
109
    std::string name() const { return "convolution"; }
Paul's avatar
Paul committed
110
111
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
112
        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
Paul's avatar
Paul committed
113

Paul's avatar
Paul committed
114
        const shape& input   = inputs.at(0);
Paul's avatar
Paul committed
115
        const shape& weights = inputs.at(1);
Paul's avatar
Paul committed
116
        auto t               = input.type();
Paul's avatar
Paul committed
117
118
        if(padding_mode == default_)
        {
Paul's avatar
Paul committed
119
120
121
            return {t,
                    {
                        input.lens()[0],
Khalique's avatar
Khalique committed
122
                        weights.lens()[0],
Paul's avatar
Paul committed
123
124
125
126
127
128
129
130
131
132
133
134
135
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
                             2 * padding[0]) /
                                    stride[0] +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
                             2 * padding[1]) /
                                    stride[1] +
                                1)),
                    }};
Paul's avatar
Paul committed
136
137
138
139
140
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
Khalique's avatar
Khalique committed
141
                     weights.lens()[0],
Paul's avatar
Paul committed
142
143
144
145
146
147
148
149
150
151
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {
                t,
                {input.lens()[0],
Khalique's avatar
Khalique committed
152
                 weights.lens()[0],
Paul's avatar
Paul committed
153
154
155
156
157
158
159
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
        }
        else
        {
Paul's avatar
Paul committed
160
            MIGRAPHX_THROW("Invalid padding mode");
Paul's avatar
Paul committed
161
        }
Paul's avatar
Paul committed
162
163
164
    }
};

Scott Thornton's avatar
Scott Thornton committed
165
166
struct im2col
{
Scott Thornton's avatar
Scott Thornton committed
167
168
169
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Khalique's avatar
Khalique committed
170

Paul's avatar
Paul committed
171
172
173
174
175
    padding_mode_t padding_mode = default_;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
176
177
178
179
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
                    f(self.padding_mode, "padding_mode"));
Paul's avatar
Paul committed
180
    }
Scott Thornton's avatar
Scott Thornton committed
181
182
183
184
185

    std::string name() const { return "im2col"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
186
187
188
        auto input          = inputs[0];
        auto weights        = inputs[1];
        auto batch_size     = input.lens()[0];
Scott Thornton's avatar
Scott Thornton committed
189
        auto input_channels = weights.lens()[1];
Scott Thornton's avatar
Scott Thornton committed
190
191
        auto kernel_height  = weights.lens()[2];
        auto kernel_width   = weights.lens()[3];
Scott Thornton's avatar
Scott Thornton committed
192
        check_shapes{inputs, *this}.has(2);
Scott Thornton's avatar
Scott Thornton committed
193
        if(batch_size != 1)
Paul's avatar
Paul committed
194
            MIGRAPHX_THROW("im2col only support batch_size 1");
Scott Thornton's avatar
Scott Thornton committed
195
        auto output_height = std::size_t(std::max<std::ptrdiff_t>(
Scott Thornton's avatar
Scott Thornton committed
196
197
198
            1,
            (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) /
                    stride[0] +
Scott Thornton's avatar
Scott Thornton committed
199
                1));
Scott Thornton's avatar
Scott Thornton committed
200
201
202
203
        auto output_width  = std::size_t(std::max<std::ptrdiff_t>(
            1,
            (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) /
                    stride[1] +
Scott Thornton's avatar
Scott Thornton committed
204
                1));
Scott Thornton's avatar
Scott Thornton committed
205
206
        auto channels_col  = kernel_height * kernel_width * input_channels;
        return {input.type(), {output_height * output_width, channels_col}};
Scott Thornton's avatar
Scott Thornton committed
207
208
209
    }
};

Paul's avatar
Paul committed
210
struct pooling
Paul's avatar
Paul committed
211
{
Paul's avatar
Paul committed
212
    std::string mode                   = "average";
Paul's avatar
Paul committed
213
214
215
    std::array<std::size_t, 2> padding = {{0, 0}};
    std::array<std::size_t, 2> stride  = {{1, 1}};
    std::array<std::size_t, 2> lengths = {{1, 1}};
Khalique's avatar
Khalique committed
216
    padding_mode_t padding_mode        = default_;
Paul's avatar
Paul committed
217
218
219
220

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
221
222
        return pack(f(self.mode, "mode"),
                    f(self.padding, "padding"),
223
                    f(self.padding, "padding_mode"),
Paul's avatar
Paul committed
224
225
                    f(self.stride, "stride"),
                    f(self.lengths, "lengths"));
Paul's avatar
Paul committed
226
227
    }

Paul's avatar
Paul committed
228
    std::string name() const { return "pooling"; }
Scott Thornton's avatar
Scott Thornton committed
229

Paul's avatar
Paul committed
230
231
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
232
        check_shapes{inputs, *this}.has(1).only_dims(4);
Paul's avatar
Paul committed
233

Paul's avatar
Paul committed
234
        const shape& input = inputs.at(0);
Paul's avatar
Paul committed
235
        auto t             = input.type();
Paul's avatar
Paul committed
236

Paul's avatar
Paul committed
237
238
        assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
        assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
Paul's avatar
Paul committed
239

240
241
        if(padding_mode == default_)
        {
Khalique's avatar
Khalique committed
242
243
            return {
                t,
Scott Thornton's avatar
Scott Thornton committed
244
245
246
247
248
                {
                    input.lens()[0],
                    input.lens()[1],
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
249
                        std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
Paul's avatar
Paul committed
250
                                                  static_cast<float>(stride[0]))) +
Scott Thornton's avatar
Scott Thornton committed
251
252
253
                            1)),
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
254
                        std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
Paul's avatar
Paul committed
255
                                                  static_cast<float>(stride[1]))) +
Scott Thornton's avatar
Scott Thornton committed
256
257
                            1)),
                }};
258
259
260
261
262
263
264
265
266
267
268
269
270
271
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
                     input.lens()[1],
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {t,
Khalique's avatar
Khalique committed
272
273
274
275
276
277
278
279
280
281
282
283
284
285
                    {
                        input.lens()[0],
                        input.lens()[1],
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) /
                                                      static_cast<float>(stride[0]))) +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) /
                                                      static_cast<float>(stride[1]))) +
                                1)),
                    }};
286
287
288
289
290
        }
        else
        {
            MIGRAPHX_THROW("Invalid padding mode");
        }
Paul's avatar
Paul committed
291
292
293
    }
};

Khalique's avatar
Khalique committed
294
295
296
297
298
299
300
301
302
/// Leaky ReLU activation. Takes exactly one input; the output has the same shape.
struct leaky_relu
{
    std::string name() const { return "leaky_relu"; }

    /// Leak coefficient (ONNX LeakyRelu `alpha`). Initialized so a default-initialized
    /// operator never carries an indeterminate value; 0.01 matches the ONNX default.
    float alpha = 0.01f;

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.alpha, "alpha"));
    }
};

/// ELU activation. Takes exactly one input; the output has the same shape.
struct elu
{
    std::string name() const { return "elu"; }

    /// ELU coefficient (ONNX Elu `alpha`). Initialized so a default-initialized operator
    /// never carries an indeterminate value; 1.0 matches the ONNX default.
    float alpha = 1.0f;

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.alpha, "alpha"));
    }
};

328
329
330
struct transpose
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
331
332
333
334

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
335
        return pack(f(self.dims, "dims"));
Paul's avatar
Paul committed
336
337
    }

338
339
340
    std::string name() const { return "transpose"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
341
        check_shapes{inputs, *this}.has(1);
342
        auto input         = inputs.at(0);
343
        auto input_lens    = input.lens();
344
345
        auto input_strides = input.strides();
        auto t             = input.type();
Paul's avatar
Paul committed
346
347
        if(dims.size() != input_lens.size())
        {
Paul's avatar
Paul committed
348
            MIGRAPHX_THROW("Permutation has wrong number of axes");
349
350
351
        }
        std::vector<int64_t> axes(dims.size());
        std::iota(axes.begin(), axes.end(), 0);
Paul's avatar
Paul committed
352
353
        if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
        {
Paul's avatar
Paul committed
354
            MIGRAPHX_THROW("Invalid permutation");
355
        }
356
357
        std::vector<size_t> output_lens(input_lens.size());
        std::vector<size_t> output_strides(input_lens.size());
Paul's avatar
Paul committed
358
        for(std::size_t i = 0; i < output_lens.size(); i++)
Paul's avatar
Paul committed
359
360
        {
            output_lens[i]    = input_lens[dims[i]];
361
362
            output_strides[i] = input_strides[dims[i]];
        }
363
        return {t, output_lens, output_strides};
364
    }
Paul's avatar
Paul committed
365
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
366
    {
Paul's avatar
Paul committed
367
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
368
    }
Paul's avatar
Paul committed
369
    int output_alias(const std::vector<shape>&) const { return 0; }
370
371
};

wsttiger's avatar
fixes  
wsttiger committed
372
373
374
375
376
377
/// The contiguous operator takes a non-standard input tensor and returns
/// the same tensor in standard form. For example, if an input tensor A with lens = (4,5)
/// is transposed to lens = (5,4), its data layout stays the same during the transpose;
/// only its shape lengths and strides change.
/// This leaves the tensor in a non-standard form. The contiguous operator copies the
/// underlying data so that the resulting tensor is in standard form again.
Paul's avatar
Paul committed
378
struct contiguous
379
380
381
382
{
    std::string name() const { return "contiguous"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
383
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
384
385
        auto lens = inputs.at(0).lens();
        auto t    = inputs.at(0).type();
386
387
        return {t, lens};
    }
Paul's avatar
Paul committed
388
389
390
391
392
393
394
395
396
397
398
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        assert(output_shape.standard());
        argument result{output_shape};
        visit_all(result, args[0])([&](auto output, auto input) {
            shape_for_each(output.get_shape(), [&](const auto& idx) {
                output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
            });
        });
        return result;
    }
399
400
};

401
402
403
404
struct concat
{
    std::size_t axis = 0;
    std::string name() const { return "concat"; }
405
    std::vector<std::size_t> compute_offsets(const shape& output_shape,
Paul's avatar
Paul committed
406
                                             const std::vector<argument>& args) const
407
408
409
410
411
412
413
414
415
416
417
    {
        std::vector<std::size_t> offsets;
        std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
        offset[axis] = 0;
        for(const auto& arg : args)
        {
            offsets.push_back(output_shape.index(offset));
            offset[axis] += arg.get_shape().lens()[axis];
        }
        return offsets;
    }
418
419
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
420
        if(inputs.empty())
421
        {
Paul's avatar
Paul committed
422
            MIGRAPHX_THROW("Number of input tensors should exceed 0");
423
424
425
        }

        const auto& first_shape_lens = inputs.front().lens();
Scott Thornton's avatar
Scott Thornton committed
426
427
428
429
430
431
432
433
434
        const auto& type             = inputs.front().type();
        for(std::size_t l = 0; l < first_shape_lens.size(); l++)
        {
            if(l != axis)
            {
                if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
                       return s.lens()[l] == first_shape_lens[l];
                   }))
                {
Paul's avatar
Paul committed
435
                    MIGRAPHX_THROW("Non-axis dimensions should match");
436
437
438
439
                }
            }
        }
        std::size_t new_dim_axis = 0;
Scott Thornton's avatar
Scott Thornton committed
440
        for(const auto& input : inputs)
441
442
443
444
445
446
447
448
449
        {
            const auto& lens = input.lens();
            new_dim_axis += lens[axis];
        }
        std::vector<std::size_t> new_lens;
        std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
        new_lens[axis] = new_dim_axis;
        return {type, new_lens};
    }
Paul's avatar
Paul committed
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
        for(std::size_t l = 0; l < args.size(); l++)
        {
            auto argl             = args[l];
            std::size_t nelements = argl.get_shape().elements();
            visit_all(result, argl)([&](auto output, auto input) {
                auto slice_shape =
                    shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
                auto slice = make_view(slice_shape, output.data() + coffsets[l]);
                // cppcheck-suppress useStlAlgorithm
                for(std::size_t i = 0; i < nelements; i++)
                {
                    slice[i] = input[i];
                }
            });
        }
        return result;
    }
471
472
};

473
474
475
476
477
struct slice
{
    std::vector<int64_t> axes;
    std::vector<int64_t> starts;
    std::vector<int64_t> ends;
Paul's avatar
Paul committed
478
479
480
481

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
482
        return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
Paul's avatar
Paul committed
483
484
    }

485
    std::string name() const { return "slice"; }
Scott Thornton's avatar
Scott Thornton committed
486
487

    auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
488
    {
Scott Thornton's avatar
Scott Thornton committed
489
        int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
Scott Thornton's avatar
Scott Thornton committed
490
491
        if(r < 0)
            r += lens[axis];
Scott Thornton's avatar
Scott Thornton committed
492
        return std::size_t(r);
Scott Thornton's avatar
Scott Thornton committed
493
494
495
496
497
498
499
    }

    auto compute_offset(const shape& s) const
    {
        const std::vector<std::size_t>& lens    = s.lens();
        const std::vector<std::size_t>& strides = s.strides();
        auto offset                             = 0;
Scott Thornton's avatar
Scott Thornton committed
500
        if(!axes.empty())
Scott Thornton's avatar
Scott Thornton committed
501
        {
Scott Thornton's avatar
Scott Thornton committed
502
503
504
505
506
            for(std::size_t i = 0; i < axes.size(); i++)
            {
                auto axis = axes[i];
                offset += fix_index(lens, axis, starts[i]) * strides[axis];
            }
507
        }
Scott Thornton's avatar
Scott Thornton committed
508
509
        else
        {
Scott Thornton's avatar
Scott Thornton committed
510
511
512
513
            for(std::size_t axis = 0; axis < lens.size(); axis++)
            {
                offset += fix_index(lens, axis, starts[axis]) * strides[axis];
            }
514
        }
Scott Thornton's avatar
Scott Thornton committed
515
516
517
518
519
        return offset;
    }

    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
520
521
522
523
        auto input_shape        = inputs[0];
        auto t                  = input_shape.type();
        const auto& old_lens    = input_shape.lens();
        const auto& old_strides = input_shape.strides();
Scott Thornton's avatar
Scott Thornton committed
524
        if(starts.size() != axes.size() || axes.size() != ends.size())
Scott Thornton's avatar
Scott Thornton committed
525
        {
Paul's avatar
Paul committed
526
            MIGRAPHX_THROW("inconsistent sizes");
527
        }
Scott Thornton's avatar
Scott Thornton committed
528
529
        std::vector<std::size_t> new_lens = old_lens;
        for(std::size_t i = 0; i < axes.size(); i++)
Scott Thornton's avatar
Scott Thornton committed
530
        {
Scott Thornton's avatar
Scott Thornton committed
531
532
533
            auto axis = axes[i];
            new_lens[axis] =
                fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
534
535
536
        }
        return shape{t, new_lens, old_strides};
    }
Paul's avatar
Paul committed
537
    argument compute(shape output_shape, std::vector<argument> args) const
538
    {
Scott Thornton's avatar
Scott Thornton committed
539
540
541
        auto input  = args[0];
        auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
        return {std::move(output_shape), [=] { return input.data() + offset; }};
542
    }
Paul's avatar
Paul committed
543
    int output_alias(const std::vector<shape>&) const { return 0; }
544
545
546
547
548
};

struct squeeze
{
    std::vector<int64_t> axes;
Paul's avatar
Paul committed
549
550
551
552

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
553
        return pack(f(self.axes, "axes"));
Paul's avatar
Paul committed
554
555
    }

556
557
558
559
    std::string name() const { return "squeeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        auto input_shape = inputs[0];
Scott Thornton's avatar
Scott Thornton committed
560
561
        auto type        = input_shape.type();
        auto old_lens    = input_shape.lens();
wsttiger's avatar
wsttiger committed
562
563
        if(std::any_of(
               axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; }))
Scott Thornton's avatar
Scott Thornton committed
564
        {
Paul's avatar
Paul committed
565
            MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
566
567
        }
        std::vector<std::size_t> new_lens;
Scott Thornton's avatar
Scott Thornton committed
568
        if(axes.empty())
Scott Thornton's avatar
Scott Thornton committed
569
        {
wsttiger's avatar
wsttiger committed
570
571
572
573
            std::copy_if(old_lens.begin(),
                         old_lens.end(),
                         std::back_inserter(new_lens),
                         [](auto len) { return len != 1; });
574
        }
Scott Thornton's avatar
Scott Thornton committed
575
576
577
578
579
580
        else
        {
            for(std::size_t i = 0; i < old_lens.size(); i++)
            {
                if(std::find(axes.begin(), axes.end(), i) == axes.end())
                {
581
582
583
584
                    new_lens.push_back(old_lens[i]);
                }
            }
        }
585
586
587
588
589
590
591
592
593
594

        // squeezing a single element generates a scalar
        if (new_lens.empty())
        {
            return {type};
        }
        else
        {
            return shape{type, new_lens};            
        }        
595
    }
Paul's avatar
Paul committed
596
    argument compute(shape output_shape, std::vector<argument> args) const
597
598
    {
        return {std::move(output_shape), std::move(args.front().data)};
Scott Thornton's avatar
Scott Thornton committed
599
    }
Paul's avatar
Paul committed
600
    int output_alias(const std::vector<shape>&) const { return 0; }
601
602
603
604
605
};

struct unsqueeze
{
    std::vector<int64_t> axes;
Paul's avatar
Paul committed
606
607
608
609

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
610
        return pack(f(self.axes, "axes"));
Paul's avatar
Paul committed
611
612
    }

613
614
615
    std::string name() const { return "unsqueeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
616
617
618
        auto input_shape     = inputs[0];
        auto type            = input_shape.type();
        auto old_lens        = input_shape.lens();
619
620
621
        std::size_t new_size = old_lens.size() + axes.size();
        std::vector<std::size_t> new_lens(new_size);
        std::size_t p = 0;
Scott Thornton's avatar
Scott Thornton committed
622
623
624
625
        for(std::size_t i = 0; i < new_size; i++)
        {
            if(std::find(axes.begin(), axes.end(), i) != axes.end())
            {
626
                new_lens[i] = 1;
Scott Thornton's avatar
Scott Thornton committed
627
628
629
            }
            else
            {
630
631
632
633
634
                new_lens[i] = old_lens[p++];
            }
        }
        return shape{type, new_lens};
    }
Paul's avatar
Paul committed
635
    argument compute(shape output_shape, std::vector<argument> args) const
636
637
638
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
Paul's avatar
Paul committed
639
    int output_alias(const std::vector<shape>&) const { return 0; }
640
641
};

Paul's avatar
Paul committed
642
643
644
struct reshape
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
645
646
647
648

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
649
        return pack(f(self.dims, "dims"));
Paul's avatar
Paul committed
650
651
    }

Paul's avatar
Paul committed
652
    std::string name() const { return "reshape"; }
Paul's avatar
Paul committed
653
654
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
655
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
656
657
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(dims.begin(), dims.end());
658
659
        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
        if(n_neg_dims > 1)
Paul's avatar
Paul committed
660
            MIGRAPHX_THROW("Dimensions for reshape can only have one -1 dim");
Paul's avatar
Paul committed
661
        for(std::size_t i = 0; i < dims.size(); i++)
Paul's avatar
Paul committed
662
663
664
        {
            if(dims[i] == 0)
                rdims[i] = idims[i];
Shucai Xiao's avatar
Shucai Xiao committed
665
666
667

            // since rdims using size_t type, -1 is the max value
            // is size_t that cause later compuation incorrect
Shucai Xiao's avatar
Shucai Xiao committed
668
            if(dims[i] == -1)
Shucai Xiao's avatar
Shucai Xiao committed
669
                rdims[i] = 1;
Paul's avatar
Paul committed
670
        }
671
672
673
        if(n_neg_dims > 0)
        {
            size_t missing_dim =
Shucai Xiao's avatar
Shucai Xiao committed
674
                inputs.front().elements() /
675
676
677
678
679
680
681
                std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
            for(std::size_t i = 0; i < rdims.size(); i++)
            {
                if(dims[i] == -1)
                    rdims[i] = missing_dim;
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
682

Scott Thornton's avatar
Scott Thornton committed
683
        shape s{inputs.front().type(), rdims};
Paul's avatar
Paul committed
684
        if(s.elements() != inputs.front().elements())
Paul's avatar
Paul committed
685
            MIGRAPHX_THROW("Wrong number of elements for reshape");
Scott Thornton's avatar
Scott Thornton committed
686
        return s;
Paul's avatar
Paul committed
687
    }
Paul's avatar
Paul committed
688
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
689
    {
Paul's avatar
Paul committed
690
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
691
    }
Paul's avatar
Paul committed
692
    int output_alias(const std::vector<shape>&) const { return 0; }
Paul's avatar
Paul committed
693
694
};

Khalique's avatar
Khalique committed
695
696
697
/// Pads the dimensions of a single input tensor.
///
/// `pads` must hold exactly 2 * rank entries. For dimension i both pads[i] and
/// pads[i + rank] are added to the input length (presumably the ONNX begin/end
/// layout — confirm against the parser). `value` and `mode` are reflected
/// attributes; no compute() is defined here, so how they are applied is up to
/// the target implementation.
struct pad
{
    std::vector<int64_t> pads;
    float value = 0.0f;
    enum pad_op_mode_t
    {
        constant_pad,
        reflect_pad,
        edge_pad
    };
    pad_op_mode_t mode = constant_pad;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.mode, "mode"), f(self.pads, "pads"), f(self.value, "value"));
    }

    std::string name() const { return "pad"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto&& idims         = inputs.front().lens();
        std::size_t num_dims = idims.size();

        // Both pads[i] and pads[i + num_dims] are read for every dimension; a
        // pads vector of any other length would be indexed out of bounds or
        // mapped onto the wrong dimensions.
        if(pads.size() != 2 * num_dims)
        {
            MIGRAPHX_THROW("PAD: pads has " + std::to_string(pads.size()) +
                           " entries, expected " + std::to_string(2 * num_dims));
        }

        std::vector<std::size_t> rdims(idims.begin(), idims.end());
        for(std::size_t i = 0; i < num_dims; i++)
        {
            rdims[i] += pads[i] + pads[i + num_dims];
        }

        shape s{inputs.front().type(), rdims};
        return s;
    }

    /// Returns true when the first half of `pads` equals the second half,
    /// i.e. every dimension receives the same pad[i] and pad[i + rank].
    bool symmetric() const
    {
        std::size_t num_dims = pads.size() / 2;
        return std::equal(
            pads.begin(), pads.begin() + num_dims, pads.begin() + num_dims, pads.end());
    }
};

Paul's avatar
Paul committed
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
/// Reinterprets a single standard-layout input as shape `s`, aliasing its data.
/// The element count must be preserved.
struct as_shape
{
    shape s;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.s, "shape"));
    }

    std::string name() const { return "as_shape"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        // Enforced with a throw instead of assert so release builds reject a
        // reinterpretation that would read past (or short of) the input buffer.
        if(inputs.front().elements() != s.elements())
        {
            MIGRAPHX_THROW("AS_SHAPE: element count mismatch, input has " +
                           std::to_string(inputs.front().elements()) + ", target shape has " +
                           std::to_string(s.elements()));
        }
        return s;
    }
    argument compute(shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
    // The output aliases input 0.
    int output_alias(const std::vector<shape>&) const { return 0; }
};

761
762
/// Gathers slices of input 0 (data) along `axis`, selected by the values of
/// input 1 (indices). The axis dimension of the data is replaced by the
/// dimensions of the indices (dropped entirely when indices is a scalar).
struct gather
{
    // Axis to gather along; negative values count from the last dimension.
    int axis = 0;

    // Reflecting `axis` makes it part of printing and of operator equality, so
    // two gathers with different axes are no longer considered identical.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"));
    }

    std::string name() const { return "gather"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2);
        auto lens = inputs[0].lens();
        int n_dim = static_cast<int>(lens.size());
        if(axis >= n_dim || axis < -n_dim)
        {
            MIGRAPHX_THROW("Gather: axis is out of range.");
        }

        // negative axis means counting dimensions from back
        int axis_index = (axis < 0) ? (n_dim + axis) : axis;

        auto type = inputs[0].type();
        lens.erase(lens.begin() + axis_index);
        if(!inputs[1].scalar())
        {
            auto ind_lens = inputs[1].lens();
            lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
        }

        // for scalar output
        if(lens.empty())
        {
            return {type};
        }

        return {type, lens};
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        // negative axis means counting dimensions from back
        int axis_index =
            (axis < 0) ? static_cast<int>(args[0].get_shape().lens().size() + axis) : axis;

        visit_all(result, args[0])([&](auto output, auto data) {
            args[1].visit([&](auto indices) {
                if(output_shape.scalar())
                {
                    // Scalar output only arises from 1-D data and a scalar index.
                    output[0] = data[indices.front()];
                }
                else
                {
                    // Walk the data's shape with the axis length replaced by the number
                    // of indices; each position reads the data element whose axis
                    // coordinate is looked up in `indices`.
                    auto out_lens        = data.get_shape().lens();
                    out_lens[axis_index] = indices.get_shape().elements();
                    migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
                    shape_for_each(out_comp_shape, [&](const auto& out_idx) {
                        auto data_idx        = out_idx;
                        data_idx[axis_index] = indices[data_idx[axis_index]];
                        output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
                            data(data_idx.begin(), data_idx.end());
                    });
                }
            });
        });

        return result;
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
829
struct dot
830
{
Paul's avatar
Paul committed
831
    float alpha = 1.0;
Shucai Xiao's avatar
Shucai Xiao committed
832
    float beta  = 1.0;
Paul's avatar
Paul committed
833
834
835
836

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
837
        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
Paul's avatar
Paul committed
838
839
    }

Shucai Xiao's avatar
Shucai Xiao committed
840
    std::string name() const { return "dot"; }
841
842
    shape compute_shape(std::vector<shape> inputs) const
    {
843
        check_shapes{inputs, *this}.same_type();
844
845
        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);
Scott Thornton's avatar
Scott Thornton committed
846
        auto t         = a.type();
847

848
        // only handle the case that the batch size of a and b are the same
Shucai Xiao's avatar
Shucai Xiao committed
849
850
        if(!std::equal(
               a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
851
        {
852
853
            MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
                           "} x {" + to_string_range(b.lens()) + "}");
854
855
        }

856
857
858
        std::size_t dim_0 = a.lens().size() - 2;
        std::size_t dim_1 = a.lens().size() - 1;
        if(a.lens()[dim_1] != b.lens()[dim_0])
Shucai Xiao's avatar
Shucai Xiao committed
859
860
        {
            MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) +
Paul's avatar
Paul committed
861
                           "} x {" + to_string_range(b.lens()) + "}");
Shucai Xiao's avatar
Shucai Xiao committed
862
863
        }

Shucai Xiao's avatar
Shucai Xiao committed
864
        auto out_lens   = a.lens();
865
        out_lens[dim_1] = b.lens()[dim_1];
Shucai Xiao's avatar
Shucai Xiao committed
866
        if(inputs.size() == 3 && out_lens != inputs.at(2).lens())
Shucai Xiao's avatar
Shucai Xiao committed
867
        {
Shucai Xiao's avatar
Shucai Xiao committed
868
869
            MIGRAPHX_THROW("DOT: dimension mismatch, operand C: {" +
                           to_string_range(inputs.at(2).lens()) +
Shucai Xiao's avatar
Shucai Xiao committed
870
                           "}, cannot add to operand A * B: {" + to_string_range(out_lens) + "}");
Shucai Xiao's avatar
Shucai Xiao committed
871
        }
Shucai Xiao's avatar
Shucai Xiao committed
872

873
        return {t, out_lens};
874
875
876
    }
};

877
struct unary
Scott Thornton's avatar
Scott Thornton committed
878
{
879
880
    shape compute_shape(std::vector<shape> inputs) const
    {
881
882
        check_shapes{inputs}.has(1);
        return inputs.at(0);
883
    }
Scott Thornton's avatar
Scott Thornton committed
884
885
};

886
struct identity
887
{
888
    std::string name() const { return "identity"; }
Scott Thornton's avatar
Scott Thornton committed
889
    shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
Paul's avatar
Paul committed
890
    argument compute(shape output_shape, std::vector<argument> args) const
891
892
893
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
Paul's avatar
Paul committed
894
    int output_alias(const std::vector<shape>&) const { return 0; }
895
896
897
};

// Single-input operators that only supply their registered name. Shape
// inference (one input, output shape == input shape) is inherited from
// `unary`; no compute() is defined for them in this header (presumably
// target implementations apply them element-wise — confirm per target).

struct abs : unary
{
    std::string name() const
    {
        return "abs";
    }
};

struct exp : unary
{
    std::string name() const
    {
        return "exp";
    }
};

struct log : unary
{
    std::string name() const
    {
        return "log";
    }
};

struct sin : unary
{
    std::string name() const
    {
        return "sin";
    }
};

struct cos : unary
{
    std::string name() const
    {
        return "cos";
    }
};

struct tan : unary
{
    std::string name() const
    {
        return "tan";
    }
};

struct asin : unary
{
    std::string name() const
    {
        return "asin";
    }
};

struct acos : unary
{
    std::string name() const
    {
        return "acos";
    }
};

struct atan : unary
{
    std::string name() const
    {
        return "atan";
    }
};

struct sinh : unary
{
    std::string name() const
    {
        return "sinh";
    }
};

struct cosh : unary
{
    std::string name() const
    {
        return "cosh";
    }
};

struct tanh : unary
{
    std::string name() const
    {
        return "tanh";
    }
};

struct sigmoid : unary
{
    std::string name() const
    {
        return "sigmoid";
    }
};

struct neg : unary
{
    std::string name() const
    {
        return "neg";
    }
};

struct relu : unary
{
    std::string name() const
    {
        return "relu";
    }
};

Paul's avatar
Paul committed
972
973
974
975
976
977
978
979
980
981
/// Softmax operator. Shape inference accepts exactly one 4-D input and returns
/// its shape unchanged; no compute() is defined here.
struct softmax
{
    std::string name() const { return "softmax"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes checker{inputs};
        checker.has(1).only_dims(4);
        return inputs.at(0);
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
982
983
984
985
986
987
988
/// Log-softmax operator parameterized by `axis`. Shape inference validates the
/// axis and returns the single input's shape unchanged; no compute() here.
struct logsoftmax
{
    int axis = 1;

    // Reflecting `axis` makes it part of printing and of operator equality, so
    // two logsoftmax ops with different axes are no longer considered identical.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"));
    }

    std::string name() const { return "logsoftmax"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
        // The cast happens only after the negativity check, which avoids a
        // signed/unsigned comparison.
        // NOTE(review): axis == rank passes this check (inclusive upper bound);
        // confirm that is intended.
        if(axis < 0 || static_cast<std::size_t>(axis) > inputs[0].lens().size())
        {
            MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
                           " is out of range");
        }
        return inputs.at(0);
    }
};

998
struct flatten
Scott Thornton's avatar
Scott Thornton committed
999
{
Paul's avatar
Paul committed
1000
    uint64_t axis = 0;
Paul's avatar
Paul committed
1001
1002
1003
1004

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1005
        return pack(f(self.axis, "axis"));
Paul's avatar
Paul committed
1006
1007
    }

Scott Thornton's avatar
Scott Thornton committed
1008
    std::string name() const { return "flatten"; }
Paul's avatar
Paul committed
1009
1010
1011
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
Paul's avatar
Paul committed
1012
1013
        auto&& lens = inputs.front().lens();

Paul's avatar
Paul committed
1014
        if(axis > lens.size())
Paul's avatar
Paul committed
1015
        {
Paul's avatar
Paul committed
1016
            MIGRAPHX_THROW("axis for flatten must be less than tensor rank");
Paul's avatar
Paul committed
1017
        }
Paul's avatar
Paul committed
1018
1019
1020
1021
        auto x =
            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
        auto y =
            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
1022
        return {inputs.at(0).type(), {x, y}};
Paul's avatar
Paul committed
1023
    }
Paul's avatar
Paul committed
1024
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
1025
    {
Paul's avatar
Paul committed
1026
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
1027
    }
Paul's avatar
Paul committed
1028
    int output_alias(const std::vector<shape>&) const { return 0; }
Scott Thornton's avatar
Scott Thornton committed
1029
};
1030

wsttiger's avatar
fixes  
wsttiger committed
1031
1032
1033
1034
1035
1036
1037
1038
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are
/// computed from multi-indicies by computing the inner product on the multi-index with the strides.
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
/// obvious from there that we can negate the effects of a given axis by setting the stride of that
/// axis to zero.
1039
1040
1041
struct broadcast
{
    uint64_t axis = 0;
Paul's avatar
Paul committed
1042
1043
1044
1045

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1046
        return pack(f(self.axis, "axis"));
Paul's avatar
Paul committed
1047
1048
    }

Scott Thornton's avatar
Scott Thornton committed
1049
    shape broadcast_shape;
1050
1051
1052
    std::string name() const { return "broadcast"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
1053
1054
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);
Paul's avatar
Paul committed
1055

Scott Thornton's avatar
Scott Thornton committed
1056
        std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0);
1057

Scott Thornton's avatar
Scott Thornton committed
1058
1059
1060
        if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) {
               return x == 1;
           }))
1061
        {
Scott Thornton's avatar
Scott Thornton committed
1062
            if(axis != 0)
Paul's avatar
Paul committed
1063
                MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0");
Scott Thornton's avatar
Scott Thornton committed
1064
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
1065
1066
1067
        }
        else
        {
Scott Thornton's avatar
Scott Thornton committed
1068
            assert(broadcast_shape.lens().size() - axis >= input.lens().size());
Scott Thornton's avatar
Scott Thornton committed
1069
1070
            if(!std::equal(
                   input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
Paul's avatar
Paul committed
1071
                MIGRAPHX_THROW("when broadcasting success sizes must match");
Paul's avatar
Paul committed
1072
            std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
Scott Thornton's avatar
Scott Thornton committed
1073
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
1074
1075
        }
    }
Paul's avatar
Paul committed
1076
    argument compute(shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
1077
    {
Scott Thornton's avatar
Scott Thornton committed
1078
        return {std::move(output_shape), std::move(args.at(0).data)};
Scott Thornton's avatar
Scott Thornton committed
1079
    }
Paul's avatar
Paul committed
1080
    int output_alias(const std::vector<shape>&) const { return 0; }
1081
1082
};

Scott Thornton's avatar
Scott Thornton committed
1083
1084
1085
struct multibroadcast
{
    std::vector<std::size_t> output_lens;
1086
1087
1088
1089
1090
1091
1092

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.output_lens, "output_lens"));
    }

Scott Thornton's avatar
Scott Thornton committed
1093
    std::string name() const { return "multibroadcast"; }
1094

Scott Thornton's avatar
Scott Thornton committed
1095
1096
1097
1098
1099
1100
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);

wsttiger's avatar
wsttiger committed
1101
        if(input.lens().empty())
Paul's avatar
Paul committed
1102
            MIGRAPHX_THROW("inputs dimensions should be > 0");
Scott Thornton's avatar
Scott Thornton committed
1103

Scott Thornton's avatar
Scott Thornton committed
1104
        if(input.lens().size() > output_lens.size())
Paul's avatar
Paul committed
1105
            MIGRAPHX_THROW("inputs dimensions should <= output size");
Scott Thornton's avatar
Scott Thornton committed
1106
1107

        std::vector<size_t> bcast_strides(output_lens.size(), 0);
Scott Thornton's avatar
Scott Thornton committed
1108
1109
        auto offset = output_lens.size() - input.lens().size();
        for(int i = input.lens().size() - 1; i >= 0; i--)
Scott Thornton's avatar
Scott Thornton committed
1110
        {
Scott Thornton's avatar
Scott Thornton committed
1111
            if(output_lens[i + offset] == input.lens()[i])
Scott Thornton's avatar
Scott Thornton committed
1112
            {
Scott Thornton's avatar
Scott Thornton committed
1113
                bcast_strides[i + offset] = input.strides()[i];
Scott Thornton's avatar
Scott Thornton committed
1114
1115
1116
1117
            }
        }
        return {t, output_lens, bcast_strides};
    }
Paul's avatar
Paul committed
1118
    argument compute(shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
1119
1120
1121
1122
1123
1124
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
    int output_alias(const std::vector<shape>&) const { return 0; }
};

Khalique's avatar
Khalique committed
1125
1126
1127
1128
1129
1130
1131
1132
1133
struct scalar
{
    shape scalar_bcast;

    std::string name() const { return "scalar"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
Paul's avatar
Paul committed
1134
        auto t = inputs.at(0).type();
Khalique's avatar
Khalique committed
1135
1136
1137
1138
        std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0);
        return {t, scalar_bcast.lens(), strides};
    }

Paul's avatar
Paul committed
1139
    argument compute(shape output_shape, std::vector<argument> args) const
Khalique's avatar
Khalique committed
1140
1141
1142
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
Paul's avatar
Paul committed
1143
    int output_alias(const std::vector<shape>&) const { return 0; }
Khalique's avatar
Khalique committed
1144
1145
};

1146
struct binary
Scott Thornton's avatar
Scott Thornton committed
1147
{
1148
1149
    shape compute_shape(std::vector<shape> inputs) const
    {
1150
        check_shapes{inputs}.has(2).same_type().same_dims();
Scott Thornton's avatar
Scott Thornton committed
1151
        auto t    = inputs.at(0).type();
1152
1153
        auto lens = inputs.at(0).lens();
        return {t, lens};
1154
    }
Scott Thornton's avatar
Scott Thornton committed
1155
1156
};

1157
1158
1159
1160
1161
1162
// Two-input operators that only supply their registered name. Shape inference
// (two inputs of identical type and dimensions) is inherited from `binary`;
// no compute() is defined for them in this header.

struct add : binary
{
    std::string name() const
    {
        return "add";
    }
};

struct sub : binary
{
    std::string name() const
    {
        return "sub";
    }
};

struct mul : binary
{
    std::string name() const
    {
        return "mul";
    }
};

struct div : binary
{
    std::string name() const
    {
        return "div";
    }
};

struct max : binary
{
    std::string name() const
    {
        return "max";
    }
};

struct min : binary
{
    std::string name() const
    {
        return "min";
    }
};

Paul's avatar
Paul committed
1187
1188
1189
1190
struct load
{
    shape s;
    std::size_t offset = 0;
Paul's avatar
Paul committed
1191
1192
1193
1194

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1195
        return pack(f(self.s, "shape"), f(self.offset, "offset"));
Paul's avatar
Paul committed
1196
1197
    }

Paul's avatar
Paul committed
1198
1199
1200
1201
1202
1203
    std::string name() const { return "load"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs}.has(1);
        return s;
    }
Paul's avatar
Paul committed
1204
    argument compute(const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
1205
    {
Paul's avatar
Paul committed
1206
1207
        if((offset + s.bytes()) > args[0].get_shape().bytes())
            MIGRAPHX_THROW("Load access is out of bounds");
Paul's avatar
Paul committed
1208
1209
        return {s, args[0].data() + offset};
    }
Paul's avatar
Paul committed
1210
    int output_alias(const std::vector<shape>&) const { return 0; }
Paul's avatar
Paul committed
1211
1212
1213
1214
1215
1216
1217
1218

    friend std::ostream& operator<<(std::ostream& os, const load& op)
    {
        os << op.name() << "[";
        os << "offset=" << op.offset << ",";
        os << "end=" << (op.offset + op.s.bytes()) << "]";
        return os;
    }
Paul's avatar
Paul committed
1219
1220
};

Paul's avatar
Paul committed
1221
struct outline
Scott Thornton's avatar
Scott Thornton committed
1222
{
Paul's avatar
Paul committed
1223
    shape s;
Paul's avatar
Paul committed
1224
1225
1226
1227

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1228
        return pack(f(self.s, "shape"));
Paul's avatar
Paul committed
1229
1230
    }

Paul's avatar
Paul committed
1231
    std::string name() const { return "outline"; }
Paul's avatar
Paul committed
1232
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
1233
    {
Paul's avatar
Paul committed
1234
        check_shapes{inputs, *this}.has(0);
Paul's avatar
Paul committed
1235
1236
        return s;
    }
Paul's avatar
Paul committed
1237
    argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
Scott Thornton's avatar
Scott Thornton committed
1238
1239
};

1240
1241
/// Direction attribute shared by the recurrent operators (rnn, gru, lstm).
/// Only `bidirectional` changes shape inference: it implies two directions.
enum class rnn_direction
{
    forward,      // presumably first-to-last time step
    reverse,      // presumably last-to-first time step
    bidirectional // both directions
};
Shucai Xiao's avatar
Shucai Xiao committed
1247

1248
1249
struct rnn
{
Shucai Xiao's avatar
Shucai Xiao committed
1250
    std::size_t hidden_size = 1;
1251
    std::vector<operation> actv_funcs{tanh{}, tanh{}};
1252
    rnn_direction direction = rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1253
    float clip              = 0.0f;
Shucai Xiao's avatar
Shucai Xiao committed
1254
1255
1256
1257

    std::string name() const { return "rnn"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Shucai Xiao's avatar
Shucai Xiao committed
1258
        auto in_dims     = inputs[0].lens();
Shucai Xiao's avatar
Shucai Xiao committed
1259
1260
        auto hidden_dims = inputs[2].lens();
        if(hidden_size != hidden_dims[2])
Shucai Xiao's avatar
Shucai Xiao committed
1261
1262
1263
1264
1265
        {
            MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
        }

        std::size_t num_directions = 1;
1266
        if(direction == rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1267
1268
1269
1270
        {
            num_directions = 2;
        }

Shucai Xiao's avatar
Shucai Xiao committed
1271
        if(num_directions != hidden_dims[0])
Shucai Xiao's avatar
Shucai Xiao committed
1272
        {
1273
            MIGRAPHX_THROW("RNN: num_direction mismatch in attribute and input");
Shucai Xiao's avatar
Shucai Xiao committed
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
        }

        std::vector<std::size_t> out_dims(in_dims);
        out_dims.insert(out_dims.begin() + 1, num_directions);
        out_dims.back() = hidden_size;

        return {inputs[0].type(), out_dims};
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
/// Drops the leading dimension of its single input; the remaining dimensions
/// form the output shape.
struct rnn_last_output
{
    std::string name() const { return "rnn_last_output"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto dims = inputs[0].lens();
        // erase(begin()) on an empty vector is undefined behavior.
        if(dims.empty())
        {
            MIGRAPHX_THROW("RNN_LAST_OUTPUT: input must have at least one dimension");
        }

        // remove the first dimension, remaining are output shape
        dims.erase(dims.begin());
        return {inputs[0].type(), dims};
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
1298
1299
1300
1301
struct gru
{
    std::size_t hidden_size = 1;
    std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
1302
    rnn_direction direction = rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1303
1304
    float clip              = 0.0f;
    int linear_before_reset = 0;
Shucai Xiao's avatar
Shucai Xiao committed
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316

    std::string name() const { return "gru"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        auto in_dims     = inputs[0].lens();
        auto hidden_dims = inputs[2].lens();
        if(hidden_size != hidden_dims[2])
        {
            MIGRAPHX_THROW("GRU: hidden size mismatch in attribute and input");
        }

        std::size_t num_directions = 1;
1317
        if(direction == rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
        {
            num_directions = 2;
        }

        if(num_directions != hidden_dims[0])
        {
            MIGRAPHX_THROW("GRU: num_direction does not match the direction attribute");
        }

        std::vector<std::size_t> out_dims(in_dims);
        out_dims.insert(out_dims.begin() + 1, num_directions);
        out_dims.back() = hidden_size;

        return {inputs[0].type(), out_dims};
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
1335
1336
1337
1338
struct lstm
{
    std::size_t hidden_size = 1;
    std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
Shucai Xiao's avatar
Shucai Xiao committed
1339
    rnn_direction direction = rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1340
1341
    float clip              = 0.0f;
    int input_forget        = 0;
Shucai Xiao's avatar
Shucai Xiao committed
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353

    std::string name() const { return "lstm"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        auto in_dims     = inputs[0].lens();
        auto hidden_dims = inputs[2].lens();
        if(hidden_size != hidden_dims[2])
        {
            MIGRAPHX_THROW("LSTM: hidden size mismatch in attribute and input");
        }

        std::size_t num_directions = 1;
Shucai Xiao's avatar
Shucai Xiao committed
1354
        if(direction == rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
        {
            num_directions = 2;
        }

        if(num_directions != hidden_dims[0])
        {
            MIGRAPHX_THROW("LSTM: num_direction does not match the direction attribute");
        }

        std::vector<std::size_t> out_dims(in_dims);
        out_dims.insert(out_dims.begin() + 1, num_directions);
        out_dims.back() = hidden_size;

        return {inputs[0].type(), out_dims};
    }
};

/// Drops the leading dimension of its single input; the remaining dimensions
/// form the output shape.
struct lstm_last_cell_output
{
    std::string name() const { return "lstm_last_cell_output"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto dims = inputs[0].lens();
        // erase(begin()) on an empty vector is undefined behavior.
        if(dims.empty())
        {
            MIGRAPHX_THROW("LSTM_LAST_CELL_OUTPUT: input must have at least one dimension");
        }

        // remove the first dimension, remaining are output shape
        dims.erase(dims.begin());
        return {inputs[0].type(), dims};
    }
};

1386
1387
1388
struct undefined
{
    std::string name() const { return "undefined"; }
Shucai Xiao's avatar
Shucai Xiao committed
1389
    shape compute_shape(const std::vector<shape>& inputs) const
1390
1391
1392
1393
1394
1395
1396
1397
    {
        check_shapes{inputs, *this}.has(0);
        return {};
    }

    argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
};

1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
/// Stand-in for an operator with no definition here. `op` is appended to
/// "unknown:" to form the name (presumably the original operator's name).
/// Shape inference forwards the first input's shape, or a default shape when
/// there are no inputs.
struct unknown
{
    std::string op;
    std::string name() const { return "unknown:" + op; }
    shape compute_shape(std::vector<shape> input) const
    {
        return input.empty() ? shape{} : input.front();
    }

    friend std::ostream& operator<<(std::ostream& os, const unknown& x)
    {
        return os << x.name();
    }
};

1417
} // namespace op
Paul's avatar
Paul committed
1418
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
1419
} // namespace migraphx
Paul's avatar
Paul committed
1420
1421

#endif