operators.hpp 41.2 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_OPERATORS_HPP
#define MIGRAPHX_GUARD_OPERATORS_HPP
Paul's avatar
Paul committed
3

4
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
namespace migraphx {
Paul's avatar
Paul committed
16
inline namespace MIGRAPHX_INLINE_NS {
17
namespace op {
Paul's avatar
Paul committed
18

19
20
21
22
23
24
25
/// Padding policy shared by the spatial operators (convolution, im2col, pooling).
///   default_ : use the operator's explicit `padding` values.
///   same     : spatial output length is ceil(input / stride).
///   valid    : no implicit padding is added.
enum padding_mode_t
{
    default_ = 0, // NOLINT
    same     = 1,
    valid    = 2
};

Paul's avatar
Paul committed
26
27
struct not_computable
{
Paul's avatar
Paul committed
28
    argument compute(const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
29
    {
Paul's avatar
Paul committed
30
        MIGRAPHX_THROW("not computable");
Paul's avatar
Paul committed
31
    }
Paul's avatar
Paul committed
32
33
};

34
35
struct batch_norm_inference
{
36
37
    float epsilon  = 1.0e-6f;
    float momentum = 0.9f;
38
39
40

    std::string name() const { return "batch_norm_inference"; }

41
42
43
44
45
46
47
48
    enum bn_infer_mode_t
    {
        per_activation,
        spatial,
    };

    bn_infer_mode_t bn_mode = spatial;

Paul's avatar
Paul committed
49
50
51
52
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(
Paul's avatar
Paul committed
53
            f(self.epsilon, "epsilon"), f(self.momentum, "momentum"), f(self.bn_mode, "bn_mode"));
Paul's avatar
Paul committed
54
    }
55

56
57
58
59
60
61
62
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        return inputs.front();
    }
};

Khalique's avatar
Khalique committed
63
struct lrn
Khalique's avatar
Khalique committed
64
65
{
    float alpha = 0.0001;
Khalique's avatar
Khalique committed
66
67
    float beta  = 0.75;
    float bias  = 1.0;
Khalique's avatar
Khalique committed
68
    int size    = 1;
Khalique's avatar
Khalique committed
69
    std::string name() const { return "lrn"; }
Khalique's avatar
Khalique committed
70

Khalique's avatar
Khalique committed
71
72
73
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Khalique's avatar
Khalique committed
74
75
76
77
        return pack(f(self.alpha, "alpha"),
                    f(self.beta, "beta"),
                    f(self.bias, "bias"),
                    f(self.size, "size"));
Khalique's avatar
Khalique committed
78
79
    }

Khalique's avatar
Khalique committed
80
81
82
83
84
85
86
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }
};

Paul's avatar
Paul committed
87
struct convolution
Paul's avatar
Paul committed
88
{
Paul's avatar
Paul committed
89
90
91
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Khalique's avatar
Khalique committed
92

Paul's avatar
Paul committed
93
    padding_mode_t padding_mode = default_;
Khalique's avatar
Khalique committed
94
    int group                   = 1;
Paul's avatar
Paul committed
95
96
97
98

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
99
100
101
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
Khalique's avatar
Khalique committed
102
103
                    f(self.padding_mode, "padding_mode"),
                    f(self.group, "group"));
Paul's avatar
Paul committed
104
105
    }

Paul's avatar
Paul committed
106
    std::string name() const { return "convolution"; }
Paul's avatar
Paul committed
107
108
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
109
        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
Paul's avatar
Paul committed
110

Paul's avatar
Paul committed
111
        const shape& input   = inputs.at(0);
Paul's avatar
Paul committed
112
        const shape& weights = inputs.at(1);
Paul's avatar
Paul committed
113
        auto t               = input.type();
Paul's avatar
Paul committed
114
115
        if(padding_mode == default_)
        {
Paul's avatar
Paul committed
116
117
118
            return {t,
                    {
                        input.lens()[0],
Khalique's avatar
Khalique committed
119
                        weights.lens()[0],
Paul's avatar
Paul committed
120
121
122
123
124
125
126
127
128
129
130
131
132
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
                             2 * padding[0]) /
                                    stride[0] +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
                             2 * padding[1]) /
                                    stride[1] +
                                1)),
                    }};
Paul's avatar
Paul committed
133
134
135
136
137
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
Khalique's avatar
Khalique committed
138
                     weights.lens()[0],
Paul's avatar
Paul committed
139
140
141
142
143
144
145
146
147
148
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {
                t,
                {input.lens()[0],
Khalique's avatar
Khalique committed
149
                 weights.lens()[0],
Paul's avatar
Paul committed
150
151
152
153
154
155
156
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
        }
        else
        {
Paul's avatar
Paul committed
157
            MIGRAPHX_THROW("Invalid padding mode");
Paul's avatar
Paul committed
158
        }
Paul's avatar
Paul committed
159
160
161
    }
};

Scott Thornton's avatar
Scott Thornton committed
162
163
struct im2col
{
Scott Thornton's avatar
Scott Thornton committed
164
165
166
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Khalique's avatar
Khalique committed
167

Paul's avatar
Paul committed
168
169
170
171
172
    padding_mode_t padding_mode = default_;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
173
174
175
176
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
                    f(self.padding_mode, "padding_mode"));
Paul's avatar
Paul committed
177
    }
Scott Thornton's avatar
Scott Thornton committed
178
179
180
181
182

    std::string name() const { return "im2col"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
183
184
185
        auto input          = inputs[0];
        auto weights        = inputs[1];
        auto batch_size     = input.lens()[0];
Scott Thornton's avatar
Scott Thornton committed
186
        auto input_channels = weights.lens()[1];
Scott Thornton's avatar
Scott Thornton committed
187
188
        auto kernel_height  = weights.lens()[2];
        auto kernel_width   = weights.lens()[3];
Scott Thornton's avatar
Scott Thornton committed
189
        check_shapes{inputs, *this}.has(2);
Scott Thornton's avatar
Scott Thornton committed
190
        if(batch_size != 1)
Paul's avatar
Paul committed
191
            MIGRAPHX_THROW("im2col only support batch_size 1");
Scott Thornton's avatar
Scott Thornton committed
192
        auto output_height = std::size_t(std::max<std::ptrdiff_t>(
Scott Thornton's avatar
Scott Thornton committed
193
194
195
            1,
            (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) /
                    stride[0] +
Scott Thornton's avatar
Scott Thornton committed
196
                1));
Scott Thornton's avatar
Scott Thornton committed
197
198
199
200
        auto output_width  = std::size_t(std::max<std::ptrdiff_t>(
            1,
            (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) /
                    stride[1] +
Scott Thornton's avatar
Scott Thornton committed
201
                1));
Scott Thornton's avatar
Scott Thornton committed
202
203
        auto channels_col  = kernel_height * kernel_width * input_channels;
        return {input.type(), {output_height * output_width, channels_col}};
Scott Thornton's avatar
Scott Thornton committed
204
205
206
    }
};

Paul's avatar
Paul committed
207
struct pooling
Paul's avatar
Paul committed
208
{
Paul's avatar
Paul committed
209
    std::string mode                   = "average";
Paul's avatar
Paul committed
210
211
212
    std::array<std::size_t, 2> padding = {{0, 0}};
    std::array<std::size_t, 2> stride  = {{1, 1}};
    std::array<std::size_t, 2> lengths = {{1, 1}};
Khalique's avatar
Khalique committed
213
    padding_mode_t padding_mode        = default_;
Paul's avatar
Paul committed
214
215
216
217

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
218
219
        return pack(f(self.mode, "mode"),
                    f(self.padding, "padding"),
220
                    f(self.padding, "padding_mode"),
Paul's avatar
Paul committed
221
222
                    f(self.stride, "stride"),
                    f(self.lengths, "lengths"));
Paul's avatar
Paul committed
223
224
    }

Paul's avatar
Paul committed
225
    std::string name() const { return "pooling"; }
Scott Thornton's avatar
Scott Thornton committed
226

Paul's avatar
Paul committed
227
228
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
229
        check_shapes{inputs, *this}.has(1).only_dims(4);
Paul's avatar
Paul committed
230

Paul's avatar
Paul committed
231
        const shape& input = inputs.at(0);
Paul's avatar
Paul committed
232
        auto t             = input.type();
Paul's avatar
Paul committed
233

Paul's avatar
Paul committed
234
235
        assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
        assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
Paul's avatar
Paul committed
236

237
238
        if(padding_mode == default_)
        {
Khalique's avatar
Khalique committed
239
240
            return {
                t,
Scott Thornton's avatar
Scott Thornton committed
241
242
243
244
245
                {
                    input.lens()[0],
                    input.lens()[1],
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
246
                        std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
Paul's avatar
Paul committed
247
                                                  static_cast<float>(stride[0]))) +
Scott Thornton's avatar
Scott Thornton committed
248
249
250
                            1)),
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
251
                        std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
Paul's avatar
Paul committed
252
                                                  static_cast<float>(stride[1]))) +
Scott Thornton's avatar
Scott Thornton committed
253
254
                            1)),
                }};
255
256
257
258
259
260
261
262
263
264
265
266
267
268
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
                     input.lens()[1],
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {t,
Khalique's avatar
Khalique committed
269
270
271
272
273
274
275
276
277
278
279
280
281
282
                    {
                        input.lens()[0],
                        input.lens()[1],
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) /
                                                      static_cast<float>(stride[0]))) +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) /
                                                      static_cast<float>(stride[1]))) +
                                1)),
                    }};
283
284
285
286
287
        }
        else
        {
            MIGRAPHX_THROW("Invalid padding mode");
        }
Paul's avatar
Paul committed
288
289
290
    }
};

Khalique's avatar
Khalique committed
291
292
293
294
295
296
297
298
299
/// Operator "leaky_relu" with parameter `alpha`; takes exactly one input and
/// returns its shape unchanged.
struct leaky_relu
{
    std::string name() const { return "leaky_relu"; }
    // Initialized so a default-constructed op never carries an indeterminate
    // value (reading it, e.g. through reflect(), was UB). 0.01 is the ONNX
    // LeakyRelu default.
    float alpha = 0.01f;
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.alpha, "alpha"));
    }
};

/// Operator "elu" with parameter `alpha`; takes exactly one input and returns
/// its shape unchanged.
struct elu
{
    std::string name() const { return "elu"; }
    // Initialized so a default-constructed op never carries an indeterminate
    // value (reading it, e.g. through reflect(), was UB). 1.0 is the ONNX Elu
    // default.
    float alpha = 1.0f;
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.alpha, "alpha"));
    }
};

325
326
327
struct transpose
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
328
329
330
331

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
332
        return pack(f(self.dims, "dims"));
Paul's avatar
Paul committed
333
334
    }

335
336
337
    std::string name() const { return "transpose"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
338
        check_shapes{inputs, *this}.has(1);
339
        auto input         = inputs.at(0);
340
        auto input_lens    = input.lens();
341
342
        auto input_strides = input.strides();
        auto t             = input.type();
Paul's avatar
Paul committed
343
344
        if(dims.size() != input_lens.size())
        {
Paul's avatar
Paul committed
345
            MIGRAPHX_THROW("Permutation has wrong number of axes");
346
347
348
        }
        std::vector<int64_t> axes(dims.size());
        std::iota(axes.begin(), axes.end(), 0);
Paul's avatar
Paul committed
349
350
        if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
        {
Paul's avatar
Paul committed
351
            MIGRAPHX_THROW("Invalid permutation");
352
        }
353
354
        std::vector<size_t> output_lens(input_lens.size());
        std::vector<size_t> output_strides(input_lens.size());
Paul's avatar
Paul committed
355
        for(std::size_t i = 0; i < output_lens.size(); i++)
Paul's avatar
Paul committed
356
357
        {
            output_lens[i]    = input_lens[dims[i]];
358
359
            output_strides[i] = input_strides[dims[i]];
        }
360
        return {t, output_lens, output_strides};
361
    }
Paul's avatar
Paul committed
362
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
363
    {
Paul's avatar
Paul committed
364
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
365
    }
Paul's avatar
Paul committed
366
    int output_alias(const std::vector<shape>&) const { return 0; }
367
368
};

wsttiger's avatar
fixes  
wsttiger committed
369
370
371
372
373
374
/// The contiguous operator takes a non-standard input tensor and returns
/// the same tensor in standard form. For example, if a tensor A with lens = (4,5)
/// is transposed to lens = (5,4), its data layout stays the same during the
/// transpose; only its shape lengths and strides change. This leaves the tensor
/// in a non-standard form. The contiguous operator copies the underlying data so
/// that the resulting tensor is in standard form.
Paul's avatar
Paul committed
375
struct contiguous
376
377
378
379
{
    std::string name() const { return "contiguous"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
380
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
381
382
        auto lens = inputs.at(0).lens();
        auto t    = inputs.at(0).type();
383
384
        return {t, lens};
    }
Paul's avatar
Paul committed
385
386
387
388
389
390
391
392
393
394
395
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        assert(output_shape.standard());
        argument result{output_shape};
        visit_all(result, args[0])([&](auto output, auto input) {
            shape_for_each(output.get_shape(), [&](const auto& idx) {
                output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
            });
        });
        return result;
    }
396
397
};

398
399
400
401
struct concat
{
    std::size_t axis = 0;
    std::string name() const { return "concat"; }
402
    std::vector<std::size_t> compute_offsets(const shape& output_shape,
Paul's avatar
Paul committed
403
                                             const std::vector<argument>& args) const
404
405
406
407
408
409
410
411
412
413
414
    {
        std::vector<std::size_t> offsets;
        std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
        offset[axis] = 0;
        for(const auto& arg : args)
        {
            offsets.push_back(output_shape.index(offset));
            offset[axis] += arg.get_shape().lens()[axis];
        }
        return offsets;
    }
415
416
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
417
        if(inputs.empty())
418
        {
Paul's avatar
Paul committed
419
            MIGRAPHX_THROW("Number of input tensors should exceed 0");
420
421
422
        }

        const auto& first_shape_lens = inputs.front().lens();
Scott Thornton's avatar
Scott Thornton committed
423
424
425
426
427
428
429
430
431
        const auto& type             = inputs.front().type();
        for(std::size_t l = 0; l < first_shape_lens.size(); l++)
        {
            if(l != axis)
            {
                if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
                       return s.lens()[l] == first_shape_lens[l];
                   }))
                {
Paul's avatar
Paul committed
432
                    MIGRAPHX_THROW("Non-axis dimensions should match");
433
434
435
436
                }
            }
        }
        std::size_t new_dim_axis = 0;
Scott Thornton's avatar
Scott Thornton committed
437
        for(const auto& input : inputs)
438
439
440
441
442
443
444
445
446
        {
            const auto& lens = input.lens();
            new_dim_axis += lens[axis];
        }
        std::vector<std::size_t> new_lens;
        std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
        new_lens[axis] = new_dim_axis;
        return {type, new_lens};
    }
Paul's avatar
Paul committed
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
        for(std::size_t l = 0; l < args.size(); l++)
        {
            auto argl             = args[l];
            std::size_t nelements = argl.get_shape().elements();
            visit_all(result, argl)([&](auto output, auto input) {
                auto slice_shape =
                    shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
                auto slice = make_view(slice_shape, output.data() + coffsets[l]);
                // cppcheck-suppress useStlAlgorithm
                for(std::size_t i = 0; i < nelements; i++)
                {
                    slice[i] = input[i];
                }
            });
        }
        return result;
    }
468
469
};

470
471
472
473
474
struct slice
{
    std::vector<int64_t> axes;
    std::vector<int64_t> starts;
    std::vector<int64_t> ends;
Paul's avatar
Paul committed
475
476
477
478

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
479
        return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
Paul's avatar
Paul committed
480
481
    }

482
    std::string name() const { return "slice"; }
Scott Thornton's avatar
Scott Thornton committed
483
484

    auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
485
    {
Scott Thornton's avatar
Scott Thornton committed
486
        int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
Scott Thornton's avatar
Scott Thornton committed
487
488
        if(r < 0)
            r += lens[axis];
Scott Thornton's avatar
Scott Thornton committed
489
        return std::size_t(r);
Scott Thornton's avatar
Scott Thornton committed
490
491
492
493
494
495
496
    }

    auto compute_offset(const shape& s) const
    {
        const std::vector<std::size_t>& lens    = s.lens();
        const std::vector<std::size_t>& strides = s.strides();
        auto offset                             = 0;
Scott Thornton's avatar
Scott Thornton committed
497
        if(!axes.empty())
Scott Thornton's avatar
Scott Thornton committed
498
        {
Scott Thornton's avatar
Scott Thornton committed
499
500
501
502
503
            for(std::size_t i = 0; i < axes.size(); i++)
            {
                auto axis = axes[i];
                offset += fix_index(lens, axis, starts[i]) * strides[axis];
            }
504
        }
Scott Thornton's avatar
Scott Thornton committed
505
506
        else
        {
Scott Thornton's avatar
Scott Thornton committed
507
508
509
510
            for(std::size_t axis = 0; axis < lens.size(); axis++)
            {
                offset += fix_index(lens, axis, starts[axis]) * strides[axis];
            }
511
        }
Scott Thornton's avatar
Scott Thornton committed
512
513
514
515
516
        return offset;
    }

    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
517
518
519
520
        auto input_shape        = inputs[0];
        auto t                  = input_shape.type();
        const auto& old_lens    = input_shape.lens();
        const auto& old_strides = input_shape.strides();
Scott Thornton's avatar
Scott Thornton committed
521
        if(starts.size() != axes.size() || axes.size() != ends.size())
Scott Thornton's avatar
Scott Thornton committed
522
        {
Paul's avatar
Paul committed
523
            MIGRAPHX_THROW("inconsistent sizes");
524
        }
Scott Thornton's avatar
Scott Thornton committed
525
526
        std::vector<std::size_t> new_lens = old_lens;
        for(std::size_t i = 0; i < axes.size(); i++)
Scott Thornton's avatar
Scott Thornton committed
527
        {
Scott Thornton's avatar
Scott Thornton committed
528
529
530
            auto axis = axes[i];
            new_lens[axis] =
                fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
531
532
533
        }
        return shape{t, new_lens, old_strides};
    }
Paul's avatar
Paul committed
534
    argument compute(shape output_shape, std::vector<argument> args) const
535
    {
Scott Thornton's avatar
Scott Thornton committed
536
537
538
        auto input  = args[0];
        auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
        return {std::move(output_shape), [=] { return input.data() + offset; }};
539
    }
Paul's avatar
Paul committed
540
    int output_alias(const std::vector<shape>&) const { return 0; }
541
542
543
544
545
};

struct squeeze
{
    std::vector<int64_t> axes;
Paul's avatar
Paul committed
546
547
548
549

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
550
        return pack(f(self.axes, "axes"));
Paul's avatar
Paul committed
551
552
    }

553
554
555
556
    std::string name() const { return "squeeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        auto input_shape = inputs[0];
Scott Thornton's avatar
Scott Thornton committed
557
558
        auto type        = input_shape.type();
        auto old_lens    = input_shape.lens();
wsttiger's avatar
wsttiger committed
559
560
        if(std::any_of(
               axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; }))
Scott Thornton's avatar
Scott Thornton committed
561
        {
Paul's avatar
Paul committed
562
            MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
563
564
        }
        std::vector<std::size_t> new_lens;
Scott Thornton's avatar
Scott Thornton committed
565
        if(axes.empty())
Scott Thornton's avatar
Scott Thornton committed
566
        {
wsttiger's avatar
wsttiger committed
567
568
569
570
            std::copy_if(old_lens.begin(),
                         old_lens.end(),
                         std::back_inserter(new_lens),
                         [](auto len) { return len != 1; });
571
        }
Scott Thornton's avatar
Scott Thornton committed
572
573
574
575
576
577
        else
        {
            for(std::size_t i = 0; i < old_lens.size(); i++)
            {
                if(std::find(axes.begin(), axes.end(), i) == axes.end())
                {
578
579
580
581
582
583
                    new_lens.push_back(old_lens[i]);
                }
            }
        }
        return shape{type, new_lens};
    }
Paul's avatar
Paul committed
584
    argument compute(shape output_shape, std::vector<argument> args) const
585
586
    {
        return {std::move(output_shape), std::move(args.front().data)};
Scott Thornton's avatar
Scott Thornton committed
587
    }
Paul's avatar
Paul committed
588
    int output_alias(const std::vector<shape>&) const { return 0; }
589
590
591
592
593
};

struct unsqueeze
{
    std::vector<int64_t> axes;
Paul's avatar
Paul committed
594
595
596
597

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
598
        return pack(f(self.axes, "axes"));
Paul's avatar
Paul committed
599
600
    }

601
602
603
    std::string name() const { return "unsqueeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
604
605
606
        auto input_shape     = inputs[0];
        auto type            = input_shape.type();
        auto old_lens        = input_shape.lens();
607
608
609
        std::size_t new_size = old_lens.size() + axes.size();
        std::vector<std::size_t> new_lens(new_size);
        std::size_t p = 0;
Scott Thornton's avatar
Scott Thornton committed
610
611
612
613
        for(std::size_t i = 0; i < new_size; i++)
        {
            if(std::find(axes.begin(), axes.end(), i) != axes.end())
            {
614
                new_lens[i] = 1;
Scott Thornton's avatar
Scott Thornton committed
615
616
617
            }
            else
            {
618
619
620
621
622
                new_lens[i] = old_lens[p++];
            }
        }
        return shape{type, new_lens};
    }
Paul's avatar
Paul committed
623
    argument compute(shape output_shape, std::vector<argument> args) const
624
625
626
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
Paul's avatar
Paul committed
627
    int output_alias(const std::vector<shape>&) const { return 0; }
628
629
};

Paul's avatar
Paul committed
630
631
632
struct reshape
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
633
634
635
636

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
637
        return pack(f(self.dims, "dims"));
Paul's avatar
Paul committed
638
639
    }

Paul's avatar
Paul committed
640
    std::string name() const { return "reshape"; }
Paul's avatar
Paul committed
641
642
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
643
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
644
645
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(dims.begin(), dims.end());
646
647
        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
        if(n_neg_dims > 1)
Paul's avatar
Paul committed
648
            MIGRAPHX_THROW("Dimensions for reshape can only have one -1 dim");
Paul's avatar
Paul committed
649
        for(std::size_t i = 0; i < dims.size(); i++)
Paul's avatar
Paul committed
650
651
652
        {
            if(dims[i] == 0)
                rdims[i] = idims[i];
Shucai Xiao's avatar
Shucai Xiao committed
653
654
655

            // since rdims using size_t type, -1 is the max value
            // is size_t that cause later compuation incorrect
Shucai Xiao's avatar
Shucai Xiao committed
656
            if(dims[i] == -1)
Shucai Xiao's avatar
Shucai Xiao committed
657
                rdims[i] = 1;
Paul's avatar
Paul committed
658
        }
659
660
661
        if(n_neg_dims > 0)
        {
            size_t missing_dim =
Shucai Xiao's avatar
Shucai Xiao committed
662
                inputs.front().elements() /
663
664
665
666
667
668
669
                std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
            for(std::size_t i = 0; i < rdims.size(); i++)
            {
                if(dims[i] == -1)
                    rdims[i] = missing_dim;
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
670

Scott Thornton's avatar
Scott Thornton committed
671
        shape s{inputs.front().type(), rdims};
Paul's avatar
Paul committed
672
        if(s.elements() != inputs.front().elements())
Paul's avatar
Paul committed
673
            MIGRAPHX_THROW("Wrong number of elements for reshape");
Scott Thornton's avatar
Scott Thornton committed
674
        return s;
Paul's avatar
Paul committed
675
    }
Paul's avatar
Paul committed
676
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
677
    {
Paul's avatar
Paul committed
678
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
679
    }
Paul's avatar
Paul committed
680
    int output_alias(const std::vector<shape>&) const { return 0; }
Paul's avatar
Paul committed
681
682
};

Khalique's avatar
Khalique committed
683
684
685
struct pad
{
    std::vector<int64_t> pads;
Khalique's avatar
Khalique committed
686
    float value = 0.0f;
Khalique's avatar
Khalique committed
687
    enum pad_op_mode_t
Khalique's avatar
Khalique committed
688
    {
Khalique's avatar
Khalique committed
689
        constant_pad,
Khalique's avatar
Khalique committed
690
        reflect_pad,
Khalique's avatar
Khalique committed
691
        edge_pad
Khalique's avatar
Khalique committed
692
    };
Khalique's avatar
Khalique committed
693
    pad_op_mode_t mode = constant_pad;
Khalique's avatar
Khalique committed
694
695
696
697
698
699
700
701
702
703
704
705
706

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.mode, "mode"), f(self.pads, "pads"), f(self.value, "value"));
    }

    std::string name() const { return "pad"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(idims.begin(), idims.end());
Khalique's avatar
Khalique committed
707
        std::size_t num_dims = rdims.size();
Khalique's avatar
Khalique committed
708

Khalique's avatar
Khalique committed
709
        for(std::size_t i = 0; i < num_dims; i++)
Khalique's avatar
Khalique committed
710
711
712
        {
            rdims[i] += pads[i] + pads[i + num_dims];
        }
Khalique's avatar
Khalique committed
713

Khalique's avatar
Khalique committed
714
715
716
        shape s{inputs.front().type(), rdims};
        return s;
    }
717
718
719
720
721
722
723
724
725
726
727

    bool symmetric() const
    {
        std::size_t num_dims = pads.size()/2;
        for(std::size_t i = 0; i < num_dims; i++)
        {
            if(pads.at(i) != pads.at(i+num_dims))
                return false;
        }
        return true;
    }
Khalique's avatar
Khalique committed
728
729
};

Paul's avatar
Paul committed
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
struct as_shape
{
    shape s;
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.s, "shape"));
    }

    std::string name() const { return "as_shape"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        assert(inputs.front().elements() == s.elements());
        return s;
    }
    argument compute(shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
    int output_alias(const std::vector<shape>&) const { return 0; }
};

753
754
struct gather
{
755
    int axis = 0;
756
757
758
759
760
761
    std::string name() const { return "gather"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2);
        auto lens = inputs[0].lens();
762
763
        int n_dim = static_cast<int>(lens.size());
        if(axis >= n_dim || axis < -n_dim)
764
        {
765
            MIGRAPHX_THROW("Gather: axis is out of range.");
766
        }
767

768
        // negative axis means counting dimensions from back
769
        int axis_index = (axis < 0) ? (n_dim + axis) : axis;
770

Shucai Xiao's avatar
Shucai Xiao committed
771
        auto type = inputs[0].type();
772
        lens.erase(lens.begin() + axis_index);
Shucai Xiao's avatar
Shucai Xiao committed
773
        if(!inputs[1].scalar())
774
775
776
777
        {
            auto ind_lens = inputs[1].lens();
            lens.insert(lens.begin() + axis_index, ind_lens.begin(), ind_lens.end());
        }
778

779
        // for scalar output
780
        if(lens.empty())
781
        {
782
            return {type};
783
        }
784
785

        return {type, lens};
786
787
788
789
790
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
791
        // negative axis means counting dimensions from back
Shucai Xiao's avatar
Shucai Xiao committed
792
793
        int axis_index =
            (axis < 0) ? static_cast<int>(args[0].get_shape().lens().size() + axis) : axis;
794

795
        // max dimension in axis
796
        visit_all(result, args[0])([&](auto output, auto data) {
797
            args[1].visit([&](auto indices) {
Shucai Xiao's avatar
Shucai Xiao committed
798
                if(output_shape.scalar())
799
800
801
                {
                    output[0] = data[indices.front()];
                }
Shucai Xiao's avatar
Shucai Xiao committed
802
                else
803
                {
Shucai Xiao's avatar
Shucai Xiao committed
804
                    auto out_lens        = data.get_shape().lens();
805
806
807
                    out_lens[axis_index] = indices.get_shape().elements();
                    migraphx::shape out_comp_shape{data.get_shape().type(), out_lens};
                    shape_for_each(out_comp_shape, [&](const auto& out_idx) {
Shucai Xiao's avatar
Shucai Xiao committed
808
                        auto data_idx        = out_idx;
809
                        data_idx[axis_index] = indices[data_idx[axis_index]];
Shucai Xiao's avatar
Shucai Xiao committed
810
                        output[out_comp_shape.index(out_idx.begin(), out_idx.end())] =
811
812
813
                            data(data_idx.begin(), data_idx.end());
                    });
                }
814
815
816
817
818
819
820
            });
        });

        return result;
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
821
struct dot
822
{
Paul's avatar
Paul committed
823
    float alpha = 1.0;
Paul's avatar
Paul committed
824
    float beta  = 0.0;
Paul's avatar
Paul committed
825
826
827
828

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
829
        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
Paul's avatar
Paul committed
830
831
    }

Shucai Xiao's avatar
Shucai Xiao committed
832
    std::string name() const { return "dot"; }
833
834
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
835
        check_shapes{inputs, *this}.has(2).same_type();
836
837
        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);
Scott Thornton's avatar
Scott Thornton committed
838
        auto t         = a.type();
839

840
841
842
843
        // according to the specification of the numpy.matmul()
        // inputs with the shape dims more than 2 are acceptable
        // as long as dim values are the same in the two inputs
        if(!std::equal(a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2))
844
        {
845
            MIGRAPHX_THROW("DOT: dim values mismatch");
846
847
        }

848
849
850
        std::size_t dim_0 = a.lens().size() - 2;
        std::size_t dim_1 = a.lens().size() - 1;
        if(a.lens()[dim_1] != b.lens()[dim_0])
Paul's avatar
Paul committed
851
852
            MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
                           "} x {" + to_string_range(b.lens()) + "}");
Shucai Xiao's avatar
Shucai Xiao committed
853
        auto out_lens   = a.lens();
854
        out_lens[dim_1] = b.lens()[dim_1];
855
        return {t, out_lens};
856
857
858
    }
};

859
struct unary
Scott Thornton's avatar
Scott Thornton committed
860
{
861
862
    shape compute_shape(std::vector<shape> inputs) const
    {
863
864
        check_shapes{inputs}.has(1);
        return inputs.at(0);
865
    }
Scott Thornton's avatar
Scott Thornton committed
866
867
};

868
struct identity
869
{
870
    std::string name() const { return "identity"; }
Scott Thornton's avatar
Scott Thornton committed
871
    shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
Paul's avatar
Paul committed
872
    argument compute(shape output_shape, std::vector<argument> args) const
873
874
875
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
Paul's avatar
Paul committed
876
    int output_alias(const std::vector<shape>&) const { return 0; }
877
878
879
};

struct abs : unary
Scott Thornton's avatar
Scott Thornton committed
880
{
881
    std::string name() const { return "abs"; }
Scott Thornton's avatar
Scott Thornton committed
882
883
};

884
struct exp : unary
Scott Thornton's avatar
Scott Thornton committed
885
{
886
    std::string name() const { return "exp"; }
Scott Thornton's avatar
Scott Thornton committed
887
888
};

Shucai Xiao's avatar
Shucai Xiao committed
889
890
891
892
893
struct log : unary
{
    std::string name() const { return "log"; }
};

894
struct sin : unary
Scott Thornton's avatar
Scott Thornton committed
895
{
896
    std::string name() const { return "sin"; }
Scott Thornton's avatar
Scott Thornton committed
897
898
};

899
struct cos : unary
Scott Thornton's avatar
Scott Thornton committed
900
{
901
    std::string name() const { return "cos"; }
Scott Thornton's avatar
Scott Thornton committed
902
903
};

904
struct tan : unary
Scott Thornton's avatar
Scott Thornton committed
905
{
906
    std::string name() const { return "tan"; }
Scott Thornton's avatar
Scott Thornton committed
907
908
};

909
struct asin : unary
Scott Thornton's avatar
Scott Thornton committed
910
{
911
    std::string name() const { return "asin"; }
Scott Thornton's avatar
Scott Thornton committed
912
913
};

914
struct acos : unary
Scott Thornton's avatar
Scott Thornton committed
915
{
916
    std::string name() const { return "acos"; }
Scott Thornton's avatar
Scott Thornton committed
917
918
};

919
struct atan : unary
Scott Thornton's avatar
Scott Thornton committed
920
{
921
    std::string name() const { return "atan"; }
Scott Thornton's avatar
Scott Thornton committed
922
923
};

924
925
926
927
928
929
930
931
932
933
struct sinh : unary
{
    std::string name() const { return "sinh"; }
};

struct cosh : unary
{
    std::string name() const { return "cosh"; }
};

934
struct tanh : unary
Scott Thornton's avatar
Scott Thornton committed
935
{
936
    std::string name() const { return "tanh"; }
Scott Thornton's avatar
Scott Thornton committed
937
938
};

939
struct sigmoid : unary
Scott Thornton's avatar
Scott Thornton committed
940
{
941
    std::string name() const { return "sigmoid"; }
Scott Thornton's avatar
Scott Thornton committed
942
943
};

944
struct neg : unary
Scott Thornton's avatar
Scott Thornton committed
945
{
946
    std::string name() const { return "neg"; }
Scott Thornton's avatar
Scott Thornton committed
947
948
};

Khalique's avatar
Khalique committed
949
950
951
952
953
struct relu : unary
{
    std::string name() const { return "relu"; }
};

Paul's avatar
Paul committed
954
955
956
957
958
959
960
961
962
963
struct softmax
{
    std::string name() const { return "softmax"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1).only_dims(4);
        return inputs.at(0);
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
964
965
966
967
968
969
970
struct logsoftmax
{
    int axis = 1;
    std::string name() const { return "logsoftmax"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
971
        if(axis < 0 || axis > inputs[0].lens().size())
Shucai Xiao's avatar
Shucai Xiao committed
972
        {
Shucai Xiao's avatar
Shucai Xiao committed
973
974
            MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
                           " is out of range");
Shucai Xiao's avatar
Shucai Xiao committed
975
976
977
978
979
        }
        return inputs.at(0);
    }
};

980
struct flatten
Scott Thornton's avatar
Scott Thornton committed
981
{
Paul's avatar
Paul committed
982
    uint64_t axis = 0;
Paul's avatar
Paul committed
983
984
985
986

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
987
        return pack(f(self.axis, "axis"));
Paul's avatar
Paul committed
988
989
    }

Scott Thornton's avatar
Scott Thornton committed
990
    std::string name() const { return "flatten"; }
Paul's avatar
Paul committed
991
992
993
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
Paul's avatar
Paul committed
994
995
        auto&& lens = inputs.front().lens();

Paul's avatar
Paul committed
996
        if(axis > lens.size())
Paul's avatar
Paul committed
997
        {
Paul's avatar
Paul committed
998
            MIGRAPHX_THROW("axis for flatten must be less than tensor rank");
Paul's avatar
Paul committed
999
        }
Paul's avatar
Paul committed
1000
1001
1002
1003
        auto x =
            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
        auto y =
            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
1004
        return {inputs.at(0).type(), {x, y}};
Paul's avatar
Paul committed
1005
    }
Paul's avatar
Paul committed
1006
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
1007
    {
Paul's avatar
Paul committed
1008
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
1009
    }
Paul's avatar
Paul committed
1010
    int output_alias(const std::vector<shape>&) const { return 0; }
Scott Thornton's avatar
Scott Thornton committed
1011
};
1012

wsttiger's avatar
fixes  
wsttiger committed
1013
1014
1015
1016
1017
1018
1019
1020
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are
/// computed from multi-indicies by computing the inner product on the multi-index with the strides.
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
/// obvious from there that we can negate the effects of a given axis by setting the stride of that
/// axis to zero.
1021
1022
1023
struct broadcast
{
    uint64_t axis = 0;
Paul's avatar
Paul committed
1024
1025
1026
1027

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1028
        return pack(f(self.axis, "axis"));
Paul's avatar
Paul committed
1029
1030
    }

Scott Thornton's avatar
Scott Thornton committed
1031
    shape broadcast_shape;
1032
1033
1034
    std::string name() const { return "broadcast"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
1035
1036
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);
Paul's avatar
Paul committed
1037

Scott Thornton's avatar
Scott Thornton committed
1038
        std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0);
1039

Scott Thornton's avatar
Scott Thornton committed
1040
1041
1042
        if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) {
               return x == 1;
           }))
1043
        {
Scott Thornton's avatar
Scott Thornton committed
1044
            if(axis != 0)
Paul's avatar
Paul committed
1045
                MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0");
Scott Thornton's avatar
Scott Thornton committed
1046
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
1047
1048
1049
        }
        else
        {
Scott Thornton's avatar
Scott Thornton committed
1050
            assert(broadcast_shape.lens().size() - axis >= input.lens().size());
Scott Thornton's avatar
Scott Thornton committed
1051
1052
            if(!std::equal(
                   input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
Paul's avatar
Paul committed
1053
                MIGRAPHX_THROW("when broadcasting success sizes must match");
Paul's avatar
Paul committed
1054
            std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
Scott Thornton's avatar
Scott Thornton committed
1055
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
1056
1057
        }
    }
Paul's avatar
Paul committed
1058
    argument compute(shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
1059
    {
Scott Thornton's avatar
Scott Thornton committed
1060
        return {std::move(output_shape), std::move(args.at(0).data)};
Scott Thornton's avatar
Scott Thornton committed
1061
    }
Paul's avatar
Paul committed
1062
    int output_alias(const std::vector<shape>&) const { return 0; }
1063
1064
};

Scott Thornton's avatar
Scott Thornton committed
1065
1066
1067
struct multibroadcast
{
    std::vector<std::size_t> output_lens;
1068
1069
1070
1071
1072
1073
1074

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.output_lens, "output_lens"));
    }

Scott Thornton's avatar
Scott Thornton committed
1075
    std::string name() const { return "multibroadcast"; }
1076

Scott Thornton's avatar
Scott Thornton committed
1077
1078
1079
1080
1081
1082
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);

wsttiger's avatar
wsttiger committed
1083
        if(input.lens().empty())
Paul's avatar
Paul committed
1084
            MIGRAPHX_THROW("inputs dimensions should be > 0");
Scott Thornton's avatar
Scott Thornton committed
1085

Scott Thornton's avatar
Scott Thornton committed
1086
        if(input.lens().size() > output_lens.size())
Paul's avatar
Paul committed
1087
            MIGRAPHX_THROW("inputs dimensions should <= output size");
Scott Thornton's avatar
Scott Thornton committed
1088
1089

        std::vector<size_t> bcast_strides(output_lens.size(), 0);
Scott Thornton's avatar
Scott Thornton committed
1090
1091
        auto offset = output_lens.size() - input.lens().size();
        for(int i = input.lens().size() - 1; i >= 0; i--)
Scott Thornton's avatar
Scott Thornton committed
1092
        {
Scott Thornton's avatar
Scott Thornton committed
1093
            if(output_lens[i + offset] == input.lens()[i])
Scott Thornton's avatar
Scott Thornton committed
1094
            {
Scott Thornton's avatar
Scott Thornton committed
1095
                bcast_strides[i + offset] = input.strides()[i];
Scott Thornton's avatar
Scott Thornton committed
1096
1097
1098
1099
            }
        }
        return {t, output_lens, bcast_strides};
    }
Paul's avatar
Paul committed
1100
    argument compute(shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
1101
1102
1103
1104
1105
1106
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
    int output_alias(const std::vector<shape>&) const { return 0; }
};

Khalique's avatar
Khalique committed
1107
1108
1109
1110
1111
1112
1113
1114
1115
struct scalar
{
    shape scalar_bcast;

    std::string name() const { return "scalar"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
Paul's avatar
Paul committed
1116
        auto t = inputs.at(0).type();
Khalique's avatar
Khalique committed
1117
1118
1119
1120
        std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0);
        return {t, scalar_bcast.lens(), strides};
    }

Paul's avatar
Paul committed
1121
    argument compute(shape output_shape, std::vector<argument> args) const
Khalique's avatar
Khalique committed
1122
1123
1124
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
Paul's avatar
Paul committed
1125
    int output_alias(const std::vector<shape>&) const { return 0; }
Khalique's avatar
Khalique committed
1126
1127
};

1128
struct binary
Scott Thornton's avatar
Scott Thornton committed
1129
{
1130
1131
    shape compute_shape(std::vector<shape> inputs) const
    {
1132
        check_shapes{inputs}.has(2).same_type().same_dims();
Scott Thornton's avatar
Scott Thornton committed
1133
        auto t    = inputs.at(0).type();
1134
1135
        auto lens = inputs.at(0).lens();
        return {t, lens};
1136
    }
Scott Thornton's avatar
Scott Thornton committed
1137
1138
};

1139
1140
1141
1142
1143
1144
struct add : binary
{
    std::string name() const { return "add"; }
};

struct sub : binary
Scott Thornton's avatar
Scott Thornton committed
1145
1146
1147
1148
{
    std::string name() const { return "sub"; }
};

1149
struct mul : binary
Scott Thornton's avatar
Scott Thornton committed
1150
1151
1152
1153
{
    std::string name() const { return "mul"; }
};

1154
struct div : binary
Scott Thornton's avatar
Scott Thornton committed
1155
1156
1157
1158
{
    std::string name() const { return "div"; }
};

Khalique's avatar
Khalique committed
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
struct max : binary
{
    std::string name() const { return "max"; }
};

struct min : binary
{
    std::string name() const { return "min"; }
};

Paul's avatar
Paul committed
1169
1170
1171
1172
struct load
{
    shape s;
    std::size_t offset = 0;
Paul's avatar
Paul committed
1173
1174
1175
1176

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1177
        return pack(f(self.s, "shape"), f(self.offset, "offset"));
Paul's avatar
Paul committed
1178
1179
    }

Paul's avatar
Paul committed
1180
1181
1182
1183
1184
1185
    std::string name() const { return "load"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs}.has(1);
        return s;
    }
Paul's avatar
Paul committed
1186
    argument compute(const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
1187
    {
Paul's avatar
Paul committed
1188
1189
        if((offset + s.bytes()) > args[0].get_shape().bytes())
            MIGRAPHX_THROW("Load access is out of bounds");
Paul's avatar
Paul committed
1190
1191
        return {s, args[0].data() + offset};
    }
Paul's avatar
Paul committed
1192
    int output_alias(const std::vector<shape>&) const { return 0; }
Paul's avatar
Paul committed
1193
1194
1195
1196
1197
1198
1199
1200

    friend std::ostream& operator<<(std::ostream& os, const load& op)
    {
        os << op.name() << "[";
        os << "offset=" << op.offset << ",";
        os << "end=" << (op.offset + op.s.bytes()) << "]";
        return os;
    }
Paul's avatar
Paul committed
1201
1202
};

Paul's avatar
Paul committed
1203
struct outline
Scott Thornton's avatar
Scott Thornton committed
1204
{
Paul's avatar
Paul committed
1205
    shape s;
Paul's avatar
Paul committed
1206
1207
1208
1209

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1210
        return pack(f(self.s, "shape"));
Paul's avatar
Paul committed
1211
1212
    }

Paul's avatar
Paul committed
1213
    std::string name() const { return "outline"; }
Paul's avatar
Paul committed
1214
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
1215
    {
Paul's avatar
Paul committed
1216
        check_shapes{inputs, *this}.has(0);
Paul's avatar
Paul committed
1217
1218
        return s;
    }
Paul's avatar
Paul committed
1219
    argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
Scott Thornton's avatar
Scott Thornton committed
1220
1221
};

1222
1223
// indicate rnn computation direction
enum class rnn_direction
Shucai Xiao's avatar
Shucai Xiao committed
1224
{
1225
1226
1227
1228
    forward,
    reverse,
    bidirectional,
};
Shucai Xiao's avatar
Shucai Xiao committed
1229

1230
1231
struct rnn
{
Shucai Xiao's avatar
Shucai Xiao committed
1232
    std::size_t hidden_size = 1;
1233
    std::vector<operation> actv_funcs{tanh{}, tanh{}};
1234
    rnn_direction direction = rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1235
    float clip              = 0.0f;
Shucai Xiao's avatar
Shucai Xiao committed
1236
1237
1238
1239

    std::string name() const { return "rnn"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Shucai Xiao's avatar
Shucai Xiao committed
1240
        auto in_dims     = inputs[0].lens();
Shucai Xiao's avatar
Shucai Xiao committed
1241
1242
        auto hidden_dims = inputs[2].lens();
        if(hidden_size != hidden_dims[2])
Shucai Xiao's avatar
Shucai Xiao committed
1243
1244
1245
1246
1247
        {
            MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
        }

        std::size_t num_directions = 1;
1248
        if(direction == rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1249
1250
1251
1252
        {
            num_directions = 2;
        }

Shucai Xiao's avatar
Shucai Xiao committed
1253
        if(num_directions != hidden_dims[0])
Shucai Xiao's avatar
Shucai Xiao committed
1254
        {
1255
            MIGRAPHX_THROW("RNN: num_direction mismatch in attribute and input");
Shucai Xiao's avatar
Shucai Xiao committed
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
        }

        std::vector<std::size_t> out_dims(in_dims);
        out_dims.insert(out_dims.begin() + 1, num_directions);
        out_dims.back() = hidden_size;

        return {inputs[0].type(), out_dims};
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
struct rnn_last_output
{
    std::string name() const { return "rnn_last_output"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto dims = inputs[0].lens();

        // remove the first dimension, remaing are output shape
        dims.erase(dims.begin());
        return {inputs[0].type(), dims};
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
1280
1281
1282
1283
struct gru
{
    std::size_t hidden_size = 1;
    std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
1284
    rnn_direction direction = rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1285
1286
    float clip              = 0.0f;
    int linear_before_reset = 0;
Shucai Xiao's avatar
Shucai Xiao committed
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298

    std::string name() const { return "gru"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        auto in_dims     = inputs[0].lens();
        auto hidden_dims = inputs[2].lens();
        if(hidden_size != hidden_dims[2])
        {
            MIGRAPHX_THROW("GRU: hidden size mismatch in attribute and input");
        }

        std::size_t num_directions = 1;
1299
        if(direction == rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
        {
            num_directions = 2;
        }

        if(num_directions != hidden_dims[0])
        {
            MIGRAPHX_THROW("GRU: num_direction does not match the direction attribute");
        }

        std::vector<std::size_t> out_dims(in_dims);
        out_dims.insert(out_dims.begin() + 1, num_directions);
        out_dims.back() = hidden_size;

        return {inputs[0].type(), out_dims};
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
1317
1318
1319
1320
struct lstm
{
    std::size_t hidden_size = 1;
    std::vector<operation> actv_funcs{sigmoid{}, tanh{}, tanh{}};
Shucai Xiao's avatar
Shucai Xiao committed
1321
    rnn_direction direction = rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1322
1323
    float clip              = 0.0f;
    int input_forget        = 0;
Shucai Xiao's avatar
Shucai Xiao committed
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335

    std::string name() const { return "lstm"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        auto in_dims     = inputs[0].lens();
        auto hidden_dims = inputs[2].lens();
        if(hidden_size != hidden_dims[2])
        {
            MIGRAPHX_THROW("LSTM: hidden size mismatch in attribute and input");
        }

        std::size_t num_directions = 1;
Shucai Xiao's avatar
Shucai Xiao committed
1336
        if(direction == rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
        {
            num_directions = 2;
        }

        if(num_directions != hidden_dims[0])
        {
            MIGRAPHX_THROW("LSTM: num_direction does not match the direction attribute");
        }

        std::vector<std::size_t> out_dims(in_dims);
        out_dims.insert(out_dims.begin() + 1, num_directions);
        out_dims.back() = hidden_size;

        return {inputs[0].type(), out_dims};
    }
};

struct lstm_last_cell_output
{
    std::string name() const { return "lstm_last_cell_output"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto dims = inputs[0].lens();

        // remove the first dimension, remaing are output shape
        dims.erase(dims.begin());
        return {inputs[0].type(), dims};
    }
};

1368
1369
1370
struct undefined
{
    std::string name() const { return "undefined"; }
Shucai Xiao's avatar
Shucai Xiao committed
1371
    shape compute_shape(const std::vector<shape>& inputs) const
1372
1373
1374
1375
1376
1377
1378
1379
    {
        check_shapes{inputs, *this}.has(0);
        return {};
    }

    argument compute(const shape&, const std::vector<argument>&) const { return {{}, nullptr}; }
};

1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
struct unknown
{
    std::string op;
    std::string name() const { return "unknown:" + op; }
    shape compute_shape(std::vector<shape> input) const
    {
        if(input.empty())
            return {};
        else
            return input.front();
    }

    friend std::ostream& operator<<(std::ostream& os, const unknown& x)
    {
        os << x.name();
        return os;
    }
};

1399
} // namespace op
Paul's avatar
Paul committed
1400
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
1401
} // namespace migraphx
Paul's avatar
Paul committed
1402
1403

#endif