"doc/en/DeepseekR1_V3_tutorial.md" did not exist on "bf1d413be0e6971186ad1c221a943061df918430"
operators.hpp 37.1 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_OPERATORS_HPP
#define MIGRAPHX_GUARD_OPERATORS_HPP
Paul's avatar
Paul committed
3

4
#include <array>
Paul's avatar
Paul committed
5
6
7
8
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
9
10
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
Paul's avatar
Paul committed
11
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
12
#include <cmath>
Paul's avatar
Paul committed
13
#include <utility>
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
namespace migraphx {
Paul's avatar
Paul committed
16
inline namespace MIGRAPHX_INLINE_NS {
17
namespace op {
Paul's avatar
Paul committed
18

19
20
21
22
23
24
25
/// Padding policy read by the convolution and pooling shape computations in
/// this file (im2col stores one as well).
enum padding_mode_t
{
    // Output extent is derived from the operator's explicit `padding` values.
    default_, // NOLINT
    // Output spatial extent is ceil(input / stride); explicit padding is not read.
    same,
    // Output extent is derived from input and window size only; explicit
    // padding is not read.
    valid
};

Paul's avatar
Paul committed
26
27
struct not_computable
{
Paul's avatar
Paul committed
28
    argument compute(const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
29
    {
Paul's avatar
Paul committed
30
        MIGRAPHX_THROW("not computable");
Paul's avatar
Paul committed
31
    }
Paul's avatar
Paul committed
32
33
};

34
35
struct batch_norm_inference
{
36
37
    float epsilon  = 1.0e-6f;
    float momentum = 0.9f;
38
39
40

    std::string name() const { return "batch_norm_inference"; }

41
42
43
44
45
46
47
48
    enum bn_infer_mode_t
    {
        per_activation,
        spatial,
    };

    bn_infer_mode_t bn_mode = spatial;

Paul's avatar
Paul committed
49
50
51
52
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(
Paul's avatar
Paul committed
53
            f(self.epsilon, "epsilon"), f(self.momentum, "momentum"), f(self.bn_mode, "bn_mode"));
Paul's avatar
Paul committed
54
    }
55

56
57
58
59
60
61
62
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        return inputs.front();
    }
};

Paul's avatar
Paul committed
63
struct convolution
Paul's avatar
Paul committed
64
{
Paul's avatar
Paul committed
65
66
67
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Khalique's avatar
Khalique committed
68

Paul's avatar
Paul committed
69
    padding_mode_t padding_mode = default_;
Khalique's avatar
Khalique committed
70
    int group                   = 1;
Paul's avatar
Paul committed
71
72
73
74

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
75
76
77
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
Khalique's avatar
Khalique committed
78
79
                    f(self.padding_mode, "padding_mode"),
                    f(self.group, "group"));
Paul's avatar
Paul committed
80
81
    }

Paul's avatar
Paul committed
82
    std::string name() const { return "convolution"; }
Paul's avatar
Paul committed
83
84
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
85
        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
Paul's avatar
Paul committed
86

Paul's avatar
Paul committed
87
        const shape& input   = inputs.at(0);
Paul's avatar
Paul committed
88
        const shape& weights = inputs.at(1);
Paul's avatar
Paul committed
89
        auto t               = input.type();
Paul's avatar
Paul committed
90
91
        if(padding_mode == default_)
        {
Paul's avatar
Paul committed
92
93
94
            return {t,
                    {
                        input.lens()[0],
Khalique's avatar
Khalique committed
95
                        weights.lens()[0],
Paul's avatar
Paul committed
96
97
98
99
100
101
102
103
104
105
106
107
108
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
                             2 * padding[0]) /
                                    stride[0] +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
                             2 * padding[1]) /
                                    stride[1] +
                                1)),
                    }};
Paul's avatar
Paul committed
109
110
111
112
113
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
Khalique's avatar
Khalique committed
114
                     weights.lens()[0],
Paul's avatar
Paul committed
115
116
117
118
119
120
121
122
123
124
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {
                t,
                {input.lens()[0],
Khalique's avatar
Khalique committed
125
                 weights.lens()[0],
Paul's avatar
Paul committed
126
127
128
129
130
131
132
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
        }
        else
        {
Paul's avatar
Paul committed
133
            MIGRAPHX_THROW("Invalid padding mode");
Paul's avatar
Paul committed
134
        }
Paul's avatar
Paul committed
135
136
137
    }
};

Scott Thornton's avatar
Scott Thornton committed
138
139
struct im2col
{
Scott Thornton's avatar
Scott Thornton committed
140
141
142
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Khalique's avatar
Khalique committed
143

Paul's avatar
Paul committed
144
145
146
147
148
    padding_mode_t padding_mode = default_;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
149
150
151
152
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
                    f(self.padding_mode, "padding_mode"));
Paul's avatar
Paul committed
153
    }
Scott Thornton's avatar
Scott Thornton committed
154
155
156
157
158

    std::string name() const { return "im2col"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
159
160
161
        auto input          = inputs[0];
        auto weights        = inputs[1];
        auto batch_size     = input.lens()[0];
Scott Thornton's avatar
Scott Thornton committed
162
        auto input_channels = weights.lens()[1];
Scott Thornton's avatar
Scott Thornton committed
163
164
        auto kernel_height  = weights.lens()[2];
        auto kernel_width   = weights.lens()[3];
Scott Thornton's avatar
Scott Thornton committed
165
        check_shapes{inputs, *this}.has(2);
Scott Thornton's avatar
Scott Thornton committed
166
        if(batch_size != 1)
Paul's avatar
Paul committed
167
            MIGRAPHX_THROW("im2col only support batch_size 1");
Scott Thornton's avatar
Scott Thornton committed
168
        auto output_height = std::size_t(std::max<std::ptrdiff_t>(
Scott Thornton's avatar
Scott Thornton committed
169
170
171
            1,
            (input.lens()[2] - (1 + dilation[0] * (kernel_height - 1)) + 2 * padding[0]) /
                    stride[0] +
Scott Thornton's avatar
Scott Thornton committed
172
                1));
Scott Thornton's avatar
Scott Thornton committed
173
174
175
176
        auto output_width  = std::size_t(std::max<std::ptrdiff_t>(
            1,
            (input.lens()[3] - (1 + dilation[1] * (kernel_width - 1)) + 2 * padding[1]) /
                    stride[1] +
Scott Thornton's avatar
Scott Thornton committed
177
                1));
Scott Thornton's avatar
Scott Thornton committed
178
179
        auto channels_col  = kernel_height * kernel_width * input_channels;
        return {input.type(), {output_height * output_width, channels_col}};
Scott Thornton's avatar
Scott Thornton committed
180
181
182
    }
};

Paul's avatar
Paul committed
183
struct pooling
Paul's avatar
Paul committed
184
{
Paul's avatar
Paul committed
185
    std::string mode                   = "average";
Paul's avatar
Paul committed
186
187
188
    std::array<std::size_t, 2> padding = {{0, 0}};
    std::array<std::size_t, 2> stride  = {{1, 1}};
    std::array<std::size_t, 2> lengths = {{1, 1}};
Khalique's avatar
Khalique committed
189
    padding_mode_t padding_mode        = default_;
Paul's avatar
Paul committed
190
191
192
193

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
194
195
        return pack(f(self.mode, "mode"),
                    f(self.padding, "padding"),
196
                    f(self.padding, "padding_mode"),
Paul's avatar
Paul committed
197
198
                    f(self.stride, "stride"),
                    f(self.lengths, "lengths"));
Paul's avatar
Paul committed
199
200
    }

Paul's avatar
Paul committed
201
    std::string name() const { return "pooling"; }
Scott Thornton's avatar
Scott Thornton committed
202

Paul's avatar
Paul committed
203
204
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
205
        check_shapes{inputs, *this}.has(1).only_dims(4);
Paul's avatar
Paul committed
206

Paul's avatar
Paul committed
207
        const shape& input = inputs.at(0);
Paul's avatar
Paul committed
208
        auto t             = input.type();
Paul's avatar
Paul committed
209

Paul's avatar
Paul committed
210
211
        assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
        assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
Paul's avatar
Paul committed
212

213
214
        if(padding_mode == default_)
        {
Khalique's avatar
Khalique committed
215
216
            return {
                t,
Scott Thornton's avatar
Scott Thornton committed
217
218
219
220
221
                {
                    input.lens()[0],
                    input.lens()[1],
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
222
                        std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
Paul's avatar
Paul committed
223
                                                  static_cast<float>(stride[0]))) +
Scott Thornton's avatar
Scott Thornton committed
224
225
226
                            1)),
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
227
                        std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
Paul's avatar
Paul committed
228
                                                  static_cast<float>(stride[1]))) +
Scott Thornton's avatar
Scott Thornton committed
229
230
                            1)),
                }};
231
232
233
234
235
236
237
238
239
240
241
242
243
244
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
                     input.lens()[1],
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {t,
Khalique's avatar
Khalique committed
245
246
247
248
249
250
251
252
253
254
255
256
257
258
                    {
                        input.lens()[0],
                        input.lens()[1],
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            std::ptrdiff_t(std::floor((input.lens()[2] - lengths[0]) /
                                                      static_cast<float>(stride[0]))) +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            std::ptrdiff_t(std::floor((input.lens()[3] - lengths[1]) /
                                                      static_cast<float>(stride[1]))) +
                                1)),
                    }};
259
260
261
262
263
        }
        else
        {
            MIGRAPHX_THROW("Invalid padding mode");
        }
Paul's avatar
Paul committed
264
265
266
    }
};

Khalique's avatar
Khalique committed
267
268
269
270
271
272
273
274
275
/// Descriptor for the leaky_relu operator: one input, output shape equal to
/// the input shape.
struct leaky_relu
{
    std::string name() const { return "leaky_relu"; }
    // Given an initializer so a default-initialized operator never holds an
    // indeterminate value; 0 is what empty-brace initialization already gave.
    float alpha = 0.0f;

    /// Requires exactly one input and returns its shape unchanged.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }

    /// Exposes `alpha` to the visitor `f`.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.alpha, "alpha"));
    }
};

/// Descriptor for the elu operator: one input, output shape equal to the
/// input shape.
struct elu
{
    std::string name() const { return "elu"; }
    // Given an initializer so a default-initialized operator never holds an
    // indeterminate value; 0 is what empty-brace initialization already gave.
    float alpha = 0.0f;

    /// Requires exactly one input and returns its shape unchanged.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }

    /// Exposes `alpha` to the visitor `f`.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.alpha, "alpha"));
    }
};

301
302
303
struct transpose
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
304
305
306
307

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
308
        return pack(f(self.dims, "dims"));
Paul's avatar
Paul committed
309
310
    }

311
312
313
    std::string name() const { return "transpose"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
314
        check_shapes{inputs, *this}.has(1);
315
        auto input         = inputs.at(0);
316
        auto input_lens    = input.lens();
317
318
        auto input_strides = input.strides();
        auto t             = input.type();
Paul's avatar
Paul committed
319
320
        if(dims.size() != input_lens.size())
        {
Paul's avatar
Paul committed
321
            MIGRAPHX_THROW("Permutation has wrong number of axes");
322
323
324
        }
        std::vector<int64_t> axes(dims.size());
        std::iota(axes.begin(), axes.end(), 0);
Paul's avatar
Paul committed
325
326
        if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
        {
Paul's avatar
Paul committed
327
            MIGRAPHX_THROW("Invalid permutation");
328
        }
329
330
        std::vector<size_t> output_lens(input_lens.size());
        std::vector<size_t> output_strides(input_lens.size());
Paul's avatar
Paul committed
331
        for(std::size_t i = 0; i < output_lens.size(); i++)
Paul's avatar
Paul committed
332
333
        {
            output_lens[i]    = input_lens[dims[i]];
334
335
            output_strides[i] = input_strides[dims[i]];
        }
336
        return {t, output_lens, output_strides};
337
    }
Paul's avatar
Paul committed
338
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
339
    {
Paul's avatar
Paul committed
340
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
341
    }
Paul's avatar
Paul committed
342
    int output_alias(const std::vector<shape>&) const { return 0; }
343
344
};

wsttiger's avatar
fixes  
wsttiger committed
345
346
347
348
349
350
/// The contiguous operator takes a non-standard input tensor and returns
/// the same tensor but in standard form. For example, if input tensor A which has lens = (4,5)
/// is first transposed, i.e. lens = (5,4), this tensor's data layout remained the same
/// during the transpose operation; only it's shape lengths and strides were changed.
/// This leaves the tensor in a non-standard form. The contiguous operator copies the
/// underlying data such that resulting tensor is returned to a standard form.
Paul's avatar
Paul committed
351
struct contiguous
352
353
354
355
{
    std::string name() const { return "contiguous"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
356
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
357
358
        auto lens = inputs.at(0).lens();
        auto t    = inputs.at(0).type();
359
360
361
362
        return {t, lens};
    }
};

363
364
365
366
struct concat
{
    std::size_t axis = 0;
    std::string name() const { return "concat"; }
367
    std::vector<std::size_t> compute_offsets(const shape& output_shape,
Paul's avatar
Paul committed
368
                                             const std::vector<argument>& args) const
369
370
371
372
373
374
375
376
377
378
379
    {
        std::vector<std::size_t> offsets;
        std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
        offset[axis] = 0;
        for(const auto& arg : args)
        {
            offsets.push_back(output_shape.index(offset));
            offset[axis] += arg.get_shape().lens()[axis];
        }
        return offsets;
    }
380
381
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
382
        if(inputs.empty())
383
        {
Paul's avatar
Paul committed
384
            MIGRAPHX_THROW("Number of input tensors should exceed 0");
385
386
387
        }

        const auto& first_shape_lens = inputs.front().lens();
Scott Thornton's avatar
Scott Thornton committed
388
389
390
391
392
393
394
395
396
        const auto& type             = inputs.front().type();
        for(std::size_t l = 0; l < first_shape_lens.size(); l++)
        {
            if(l != axis)
            {
                if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
                       return s.lens()[l] == first_shape_lens[l];
                   }))
                {
Paul's avatar
Paul committed
397
                    MIGRAPHX_THROW("Non-axis dimensions should match");
398
399
400
401
                }
            }
        }
        std::size_t new_dim_axis = 0;
Scott Thornton's avatar
Scott Thornton committed
402
        for(const auto& input : inputs)
403
404
405
406
407
408
409
410
411
        {
            const auto& lens = input.lens();
            new_dim_axis += lens[axis];
        }
        std::vector<std::size_t> new_lens;
        std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
        new_lens[axis] = new_dim_axis;
        return {type, new_lens};
    }
Paul's avatar
Paul committed
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        std::vector<std::size_t> coffsets = compute_offsets(output_shape, args);
        for(std::size_t l = 0; l < args.size(); l++)
        {
            auto argl             = args[l];
            std::size_t nelements = argl.get_shape().elements();
            visit_all(result, argl)([&](auto output, auto input) {
                auto slice_shape =
                    shape{output_shape.type(), input.get_shape().lens(), output_shape.strides()};
                auto slice = make_view(slice_shape, output.data() + coffsets[l]);
                // cppcheck-suppress useStlAlgorithm
                for(std::size_t i = 0; i < nelements; i++)
                {
                    slice[i] = input[i];
                }
            });
        }
        return result;
    }
433
434
};

435
436
437
438
439
struct slice
{
    std::vector<int64_t> axes;
    std::vector<int64_t> starts;
    std::vector<int64_t> ends;
Paul's avatar
Paul committed
440
441
442
443

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
444
        return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
Paul's avatar
Paul committed
445
446
    }

447
    std::string name() const { return "slice"; }
Scott Thornton's avatar
Scott Thornton committed
448
449

    auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
450
    {
Scott Thornton's avatar
Scott Thornton committed
451
        int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
Scott Thornton's avatar
Scott Thornton committed
452
453
        if(r < 0)
            r += lens[axis];
Scott Thornton's avatar
Scott Thornton committed
454
        return std::size_t(r);
Scott Thornton's avatar
Scott Thornton committed
455
456
457
458
459
460
461
    }

    auto compute_offset(const shape& s) const
    {
        const std::vector<std::size_t>& lens    = s.lens();
        const std::vector<std::size_t>& strides = s.strides();
        auto offset                             = 0;
Scott Thornton's avatar
Scott Thornton committed
462
        if(!axes.empty())
Scott Thornton's avatar
Scott Thornton committed
463
        {
Scott Thornton's avatar
Scott Thornton committed
464
465
466
467
468
            for(std::size_t i = 0; i < axes.size(); i++)
            {
                auto axis = axes[i];
                offset += fix_index(lens, axis, starts[i]) * strides[axis];
            }
469
        }
Scott Thornton's avatar
Scott Thornton committed
470
471
        else
        {
Scott Thornton's avatar
Scott Thornton committed
472
473
474
475
            for(std::size_t axis = 0; axis < lens.size(); axis++)
            {
                offset += fix_index(lens, axis, starts[axis]) * strides[axis];
            }
476
        }
Scott Thornton's avatar
Scott Thornton committed
477
478
479
480
481
        return offset;
    }

    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
482
483
484
485
        auto input_shape        = inputs[0];
        auto t                  = input_shape.type();
        const auto& old_lens    = input_shape.lens();
        const auto& old_strides = input_shape.strides();
Scott Thornton's avatar
Scott Thornton committed
486
        if(starts.size() != axes.size() || axes.size() != ends.size())
Scott Thornton's avatar
Scott Thornton committed
487
        {
Paul's avatar
Paul committed
488
            MIGRAPHX_THROW("inconsistent sizes");
489
        }
Scott Thornton's avatar
Scott Thornton committed
490
491
        std::vector<std::size_t> new_lens = old_lens;
        for(std::size_t i = 0; i < axes.size(); i++)
Scott Thornton's avatar
Scott Thornton committed
492
        {
Scott Thornton's avatar
Scott Thornton committed
493
494
495
            auto axis = axes[i];
            new_lens[axis] =
                fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
496
497
498
        }
        return shape{t, new_lens, old_strides};
    }
Paul's avatar
Paul committed
499
    argument compute(shape output_shape, std::vector<argument> args) const
500
    {
Scott Thornton's avatar
Scott Thornton committed
501
502
503
        auto input  = args[0];
        auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
        return {std::move(output_shape), [=] { return input.data() + offset; }};
504
    }
Paul's avatar
Paul committed
505
    int output_alias(const std::vector<shape>&) const { return 0; }
506
507
508
509
510
};

struct squeeze
{
    std::vector<int64_t> axes;
Paul's avatar
Paul committed
511
512
513
514

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
515
        return pack(f(self.axes, "axes"));
Paul's avatar
Paul committed
516
517
    }

518
519
520
521
    std::string name() const { return "squeeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        auto input_shape = inputs[0];
Scott Thornton's avatar
Scott Thornton committed
522
523
        auto type        = input_shape.type();
        auto old_lens    = input_shape.lens();
wsttiger's avatar
wsttiger committed
524
525
        if(std::any_of(
               axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; }))
Scott Thornton's avatar
Scott Thornton committed
526
        {
Paul's avatar
Paul committed
527
            MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
528
529
        }
        std::vector<std::size_t> new_lens;
Scott Thornton's avatar
Scott Thornton committed
530
        if(axes.empty())
Scott Thornton's avatar
Scott Thornton committed
531
        {
wsttiger's avatar
wsttiger committed
532
533
534
535
            std::copy_if(old_lens.begin(),
                         old_lens.end(),
                         std::back_inserter(new_lens),
                         [](auto len) { return len != 1; });
536
        }
Scott Thornton's avatar
Scott Thornton committed
537
538
539
540
541
542
        else
        {
            for(std::size_t i = 0; i < old_lens.size(); i++)
            {
                if(std::find(axes.begin(), axes.end(), i) == axes.end())
                {
543
544
545
546
547
548
                    new_lens.push_back(old_lens[i]);
                }
            }
        }
        return shape{type, new_lens};
    }
Paul's avatar
Paul committed
549
    argument compute(shape output_shape, std::vector<argument> args) const
550
551
    {
        return {std::move(output_shape), std::move(args.front().data)};
Scott Thornton's avatar
Scott Thornton committed
552
    }
Paul's avatar
Paul committed
553
    int output_alias(const std::vector<shape>&) const { return 0; }
554
555
556
557
558
};

struct unsqueeze
{
    std::vector<int64_t> axes;
Paul's avatar
Paul committed
559
560
561
562

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
563
        return pack(f(self.axes, "axes"));
Paul's avatar
Paul committed
564
565
    }

566
567
568
    std::string name() const { return "unsqueeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
569
570
571
        auto input_shape     = inputs[0];
        auto type            = input_shape.type();
        auto old_lens        = input_shape.lens();
572
573
574
        std::size_t new_size = old_lens.size() + axes.size();
        std::vector<std::size_t> new_lens(new_size);
        std::size_t p = 0;
Scott Thornton's avatar
Scott Thornton committed
575
576
577
578
        for(std::size_t i = 0; i < new_size; i++)
        {
            if(std::find(axes.begin(), axes.end(), i) != axes.end())
            {
579
                new_lens[i] = 1;
Scott Thornton's avatar
Scott Thornton committed
580
581
582
            }
            else
            {
583
584
585
586
587
                new_lens[i] = old_lens[p++];
            }
        }
        return shape{type, new_lens};
    }
Paul's avatar
Paul committed
588
    argument compute(shape output_shape, std::vector<argument> args) const
589
590
591
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
Paul's avatar
Paul committed
592
    int output_alias(const std::vector<shape>&) const { return 0; }
593
594
};

Paul's avatar
Paul committed
595
596
597
struct reshape
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
598
599
600
601

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
602
        return pack(f(self.dims, "dims"));
Paul's avatar
Paul committed
603
604
    }

Paul's avatar
Paul committed
605
    std::string name() const { return "reshape"; }
Paul's avatar
Paul committed
606
607
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
608
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
609
610
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(dims.begin(), dims.end());
611
612
        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
        if(n_neg_dims > 1)
Paul's avatar
Paul committed
613
            MIGRAPHX_THROW("Dimensions for reshape can only have one -1 dim");
Paul's avatar
Paul committed
614
        for(std::size_t i = 0; i < dims.size(); i++)
Paul's avatar
Paul committed
615
616
617
        {
            if(dims[i] == 0)
                rdims[i] = idims[i];
Shucai Xiao's avatar
Shucai Xiao committed
618
619
620

            // since rdims using size_t type, -1 is the max value
            // is size_t that cause later compuation incorrect
Shucai Xiao's avatar
Shucai Xiao committed
621
            if(dims[i] == -1)
Shucai Xiao's avatar
Shucai Xiao committed
622
                rdims[i] = 1;
Paul's avatar
Paul committed
623
        }
624
625
626
        if(n_neg_dims > 0)
        {
            size_t missing_dim =
Shucai Xiao's avatar
Shucai Xiao committed
627
                inputs.front().elements() /
628
629
630
631
632
633
634
                std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
            for(std::size_t i = 0; i < rdims.size(); i++)
            {
                if(dims[i] == -1)
                    rdims[i] = missing_dim;
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
635

Scott Thornton's avatar
Scott Thornton committed
636
        shape s{inputs.front().type(), rdims};
Paul's avatar
Paul committed
637
        if(s.elements() != inputs.front().elements())
Paul's avatar
Paul committed
638
            MIGRAPHX_THROW("Wrong number of elements for reshape");
Scott Thornton's avatar
Scott Thornton committed
639
        return s;
Paul's avatar
Paul committed
640
    }
Paul's avatar
Paul committed
641
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
642
    {
Paul's avatar
Paul committed
643
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
644
    }
Paul's avatar
Paul committed
645
    int output_alias(const std::vector<shape>&) const { return 0; }
Paul's avatar
Paul committed
646
647
};

Khalique's avatar
Khalique committed
648
649
650
/// Padding operator: enlarges every dimension of its single input by the
/// amounts stored in `pads`.
struct pad
{
    /// Padding amounts. For an input of rank r this must hold 2 * r entries;
    /// entries i and i + r are both added to dimension i.
    std::vector<int64_t> pads;
    /// Fill value (not used by the shape computation in this struct).
    float value = 0.0f;
    enum pad_op_mode_t
    {
        constant_pad,
        reflect_pad,
        edge_pad
    };
    pad_op_mode_t mode = constant_pad;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.mode, "mode"), f(self.pads, "pads"), f(self.value, "value"));
    }

    std::string name() const { return "pad"; }

    /// Computes the padded output shape.
    /// Throws if there is not exactly one input or if `pads` does not contain
    /// two entries per input dimension.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(idims.begin(), idims.end());
        std::size_t num_dims = rdims.size();

        // pads[i + num_dims] is read below, so the size must be validated first;
        // otherwise a short `pads` vector is read out of bounds.
        if(pads.size() != 2 * num_dims)
            MIGRAPHX_THROW("pad: pads must contain two entries per input dimension");

        for(std::size_t i = 0; i < num_dims; i++)
        {
            rdims[i] += pads[i] + pads[i + num_dims];
        }

        shape s{inputs.front().type(), rdims};
        return s;
    }
};

Paul's avatar
Paul committed
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
/// Operator that presents its single input under the stored shape `s`.
struct as_shape
{
    /// Shape reported as the output shape.
    shape s;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.s, "shape"));
    }

    std::string name() const
    {
        return "as_shape";
    }

    /// Requires exactly one input that passes the `standard()` check; the
    /// element-count match against `s` is only asserted (debug builds).
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        assert(inputs.front().elements() == s.elements());
        return s;
    }

    /// Builds the result from `output_shape` and the `data` member moved out of
    /// the first argument.
    argument compute(shape output_shape, std::vector<argument> args) const
    {
        auto& source = args.front();
        return {std::move(output_shape), std::move(source.data)};
    }

    /// Always returns 0.
    int output_alias(const std::vector<shape>&) const
    {
        return 0;
    }
};

707
708
struct gather
{
709
    int axis = 0;
710
711
712
713
714
715
    std::string name() const { return "gather"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(2);
        auto lens = inputs[0].lens();
716
717
        int n_dim = static_cast<int>(lens.size());
        if(axis >= n_dim || axis < -n_dim)
718
        {
719
            MIGRAPHX_THROW("Gather: axis is out of range.");
720
        }
721

722
        // negative axis means counting dimensions from back
723
        int axis_index = (axis < 0) ? (n_dim + axis) : axis;
724

Shucai Xiao's avatar
Shucai Xiao committed
725
        auto type        = inputs[0].type();
726
        lens[axis_index] = inputs[1].elements();
727
728
729
730
731

        return {type, lens};
    }

    template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
732
    void compute_index(const T& out_idx,
Shucai Xiao's avatar
Shucai Xiao committed
733
                       const int axis_index,
Shucai Xiao's avatar
Shucai Xiao committed
734
735
736
                       const std::vector<std::size_t>& vec_indices,
                       const std::size_t max_dim,
                       T& in_idx) const
737
    {
Shucai Xiao's avatar
Shucai Xiao committed
738
        in_idx          = out_idx;
739
        std::size_t idx = vec_indices.at(out_idx[axis_index]);
740
741
        if(idx >= max_dim)
        {
742
            MIGRAPHX_THROW("Gather: indices are out of range in input tensor");
743
        }
744
        in_idx[axis_index] = idx;
745
746
747
748
749
    }

    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
750
751
752
        // negative axis means counting dimensions from back
        int axis_index = (axis < 0) ? (output_shape.lens().size() + axis) : axis;

753
        // max dimension in axis
754
        std::size_t max_dim = args[0].get_shape().lens()[axis_index];
755
756
        std::vector<std::size_t> vec_indices;
        args[1].visit([&](auto indices) { vec_indices.assign(indices.begin(), indices.end()); });
757
        visit_all(result, args[0])([&](auto output, auto input) {
758
            std::vector<std::size_t> in_idx;
759
            shape_for_each(output.get_shape(), [&](const auto& idx) {
760
                this->compute_index(idx, axis_index, vec_indices, max_dim, in_idx);
761
762
763
764
765
766
767
768
                output(idx.begin(), idx.end()) = input(in_idx.begin(), in_idx.end());
            });
        });

        return result;
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
769
struct dot
770
{
Paul's avatar
Paul committed
771
    float alpha = 1.0;
Paul's avatar
Paul committed
772
    float beta  = 0.0;
Paul's avatar
Paul committed
773
774
775
776

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
777
        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
Paul's avatar
Paul committed
778
779
    }

Shucai Xiao's avatar
Shucai Xiao committed
780
    std::string name() const { return "dot"; }
781
782
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
783
        check_shapes{inputs, *this}.has(2).same_type();
784
785
        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);
Scott Thornton's avatar
Scott Thornton committed
786
        auto t         = a.type();
787

788
        if(a.lens()[1] != b.lens()[0])
Paul's avatar
Paul committed
789
790
            MIGRAPHX_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) +
                           "} x {" + to_string_range(b.lens()) + "}");
Scott Thornton's avatar
Scott Thornton committed
791
        return {t, {a.lens()[0], b.lens()[1]}};
792
793
794
    }
};

795
struct unary
Scott Thornton's avatar
Scott Thornton committed
796
{
797
798
    shape compute_shape(std::vector<shape> inputs) const
    {
799
800
        check_shapes{inputs}.has(1);
        return inputs.at(0);
801
    }
Scott Thornton's avatar
Scott Thornton committed
802
803
};

804
struct identity
805
{
806
    std::string name() const { return "identity"; }
Scott Thornton's avatar
Scott Thornton committed
807
    shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
Paul's avatar
Paul committed
808
    argument compute(shape output_shape, std::vector<argument> args) const
809
810
811
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
Paul's avatar
Paul committed
812
    int output_alias(const std::vector<shape>&) const { return 0; }
813
814
815
};

struct abs : unary
Scott Thornton's avatar
Scott Thornton committed
816
{
817
    std::string name() const { return "abs"; }
Scott Thornton's avatar
Scott Thornton committed
818
819
};

820
struct exp : unary
Scott Thornton's avatar
Scott Thornton committed
821
{
822
    std::string name() const { return "exp"; }
Scott Thornton's avatar
Scott Thornton committed
823
824
};

Shucai Xiao's avatar
Shucai Xiao committed
825
826
827
828
829
/// Operator named "log"; shape computation is inherited from `unary`.
struct log : unary
{
    std::string name() const
    {
        return "log";
    }
};

830
struct sin : unary
Scott Thornton's avatar
Scott Thornton committed
831
{
832
    std::string name() const { return "sin"; }
Scott Thornton's avatar
Scott Thornton committed
833
834
};

835
struct cos : unary
Scott Thornton's avatar
Scott Thornton committed
836
{
837
    std::string name() const { return "cos"; }
Scott Thornton's avatar
Scott Thornton committed
838
839
};

840
struct tan : unary
Scott Thornton's avatar
Scott Thornton committed
841
{
842
    std::string name() const { return "tan"; }
Scott Thornton's avatar
Scott Thornton committed
843
844
};

845
struct asin : unary
Scott Thornton's avatar
Scott Thornton committed
846
{
847
    std::string name() const { return "asin"; }
Scott Thornton's avatar
Scott Thornton committed
848
849
};

850
struct acos : unary
Scott Thornton's avatar
Scott Thornton committed
851
{
852
    std::string name() const { return "acos"; }
Scott Thornton's avatar
Scott Thornton committed
853
854
};

855
struct atan : unary
Scott Thornton's avatar
Scott Thornton committed
856
{
857
    std::string name() const { return "atan"; }
Scott Thornton's avatar
Scott Thornton committed
858
859
};

860
861
862
863
864
865
866
867
868
869
/// Operator named "sinh"; shape computation is inherited from `unary`.
struct sinh : unary
{
    std::string name() const
    {
        return "sinh";
    }
};

/// Operator named "cosh"; shape computation is inherited from `unary`.
struct cosh : unary
{
    std::string name() const
    {
        return "cosh";
    }
};

870
struct tanh : unary
Scott Thornton's avatar
Scott Thornton committed
871
{
872
    std::string name() const { return "tanh"; }
Scott Thornton's avatar
Scott Thornton committed
873
874
};

875
struct sigmoid : unary
Scott Thornton's avatar
Scott Thornton committed
876
{
877
    std::string name() const { return "sigmoid"; }
Scott Thornton's avatar
Scott Thornton committed
878
879
};

880
struct neg : unary
Scott Thornton's avatar
Scott Thornton committed
881
{
882
    std::string name() const { return "neg"; }
Scott Thornton's avatar
Scott Thornton committed
883
884
};

Khalique's avatar
Khalique committed
885
886
887
888
889
/// Operator named "relu"; shape computation is inherited from `unary`.
struct relu : unary
{
    std::string name() const
    {
        return "relu";
    }
};

Paul's avatar
Paul committed
890
891
892
893
894
895
896
897
898
899
/// Softmax operator: shape logic only; the output shape equals the input shape.
struct softmax
{
    std::string name() const
    {
        return "softmax";
    }

    /// Requires exactly one input that passes `only_dims(4)` and returns a copy
    /// of its shape.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1).only_dims(4);
        const shape& only_input = inputs.at(0);
        return only_input;
    }
};

900
struct flatten
Scott Thornton's avatar
Scott Thornton committed
901
{
Paul's avatar
Paul committed
902
    uint64_t axis = 0;
Paul's avatar
Paul committed
903
904
905
906

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
907
        return pack(f(self.axis, "axis"));
Paul's avatar
Paul committed
908
909
    }

Scott Thornton's avatar
Scott Thornton committed
910
    std::string name() const { return "flatten"; }
Paul's avatar
Paul committed
911
912
913
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
Paul's avatar
Paul committed
914
915
        auto&& lens = inputs.front().lens();

Paul's avatar
Paul committed
916
        if(axis > lens.size())
Paul's avatar
Paul committed
917
        {
Paul's avatar
Paul committed
918
            MIGRAPHX_THROW("axis for flatten must be less than tensor rank");
Paul's avatar
Paul committed
919
        }
Paul's avatar
Paul committed
920
921
922
923
        auto x =
            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
        auto y =
            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
924
        return {inputs.at(0).type(), {x, y}};
Paul's avatar
Paul committed
925
    }
Paul's avatar
Paul committed
926
    argument compute(shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
927
    {
Paul's avatar
Paul committed
928
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
929
    }
Paul's avatar
Paul committed
930
    int output_alias(const std::vector<shape>&) const { return 0; }
Scott Thornton's avatar
Scott Thornton committed
931
};
932

wsttiger's avatar
fixes  
wsttiger committed
933
934
935
936
937
938
939
940
/// The broadcast operator performs numpy-style broadcasting of an axis of a given tensor. This
/// is achieved primarily by setting the stride of the broadcast axis to zero. Linear indices are
/// computed from multi-indices by taking the inner product of the multi-index with the strides.
/// For example, a tensor A(2,3) has lengths of (2,3) and strides of (3,1). To compute the
/// linear offset of the element in the 2nd row (i = 1) and 3rd column (j = 2), we take the
/// inner product (1,2) dot (3,1) = 1*3 + 2*1 = 5. It follows that the effect of a given axis
/// on the offset can be cancelled by setting the stride of that axis to zero.
941
942
943
struct broadcast
{
    uint64_t axis = 0;
Paul's avatar
Paul committed
944
945
946
947

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
948
        return pack(f(self.axis, "axis"));
Paul's avatar
Paul committed
949
950
    }

Scott Thornton's avatar
Scott Thornton committed
951
    shape broadcast_shape;
952
953
954
    std::string name() const { return "broadcast"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
955
956
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);
Paul's avatar
Paul committed
957

Scott Thornton's avatar
Scott Thornton committed
958
        std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0);
959

Scott Thornton's avatar
Scott Thornton committed
960
961
962
        if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) {
               return x == 1;
           }))
963
        {
Scott Thornton's avatar
Scott Thornton committed
964
            if(axis != 0)
Paul's avatar
Paul committed
965
                MIGRAPHX_THROW("when broadcasting tensor of size 1, axis should be 0");
Scott Thornton's avatar
Scott Thornton committed
966
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
967
968
969
        }
        else
        {
Scott Thornton's avatar
Scott Thornton committed
970
            assert(broadcast_shape.lens().size() - axis >= input.lens().size());
Scott Thornton's avatar
Scott Thornton committed
971
972
            if(!std::equal(
                   input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
Paul's avatar
Paul committed
973
                MIGRAPHX_THROW("when broadcasting success sizes must match");
Paul's avatar
Paul committed
974
            std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
Scott Thornton's avatar
Scott Thornton committed
975
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
976
977
        }
    }
Paul's avatar
Paul committed
978
    argument compute(shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
979
    {
Scott Thornton's avatar
Scott Thornton committed
980
        return {std::move(output_shape), std::move(args.at(0).data)};
Scott Thornton's avatar
Scott Thornton committed
981
    }
Paul's avatar
Paul committed
982
    int output_alias(const std::vector<shape>&) const { return 0; }
983
984
};

Scott Thornton's avatar
Scott Thornton committed
985
986
987
struct multibroadcast
{
    std::vector<std::size_t> output_lens;
988
989
990
991
992
993
994

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.output_lens, "output_lens"));
    }

Scott Thornton's avatar
Scott Thornton committed
995
    std::string name() const { return "multibroadcast"; }
996

Scott Thornton's avatar
Scott Thornton committed
997
998
999
1000
1001
1002
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);

wsttiger's avatar
wsttiger committed
1003
        if(input.lens().empty())
Paul's avatar
Paul committed
1004
            MIGRAPHX_THROW("inputs dimensions should be > 0");
Scott Thornton's avatar
Scott Thornton committed
1005

Scott Thornton's avatar
Scott Thornton committed
1006
        if(input.lens().size() > output_lens.size())
Paul's avatar
Paul committed
1007
            MIGRAPHX_THROW("inputs dimensions should <= output size");
Scott Thornton's avatar
Scott Thornton committed
1008
1009

        std::vector<size_t> bcast_strides(output_lens.size(), 0);
Scott Thornton's avatar
Scott Thornton committed
1010
1011
        auto offset = output_lens.size() - input.lens().size();
        for(int i = input.lens().size() - 1; i >= 0; i--)
Scott Thornton's avatar
Scott Thornton committed
1012
        {
Scott Thornton's avatar
Scott Thornton committed
1013
            if(output_lens[i + offset] == input.lens()[i])
Scott Thornton's avatar
Scott Thornton committed
1014
            {
Scott Thornton's avatar
Scott Thornton committed
1015
                bcast_strides[i + offset] = input.strides()[i];
Scott Thornton's avatar
Scott Thornton committed
1016
1017
1018
1019
            }
        }
        return {t, output_lens, bcast_strides};
    }
Paul's avatar
Paul committed
1020
    argument compute(shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
1021
1022
1023
1024
1025
1026
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
    int output_alias(const std::vector<shape>&) const { return 0; }
};

Khalique's avatar
Khalique committed
1027
1028
1029
1030
1031
1032
1033
1034
1035
struct scalar
{
    shape scalar_bcast;

    std::string name() const { return "scalar"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
Paul's avatar
Paul committed
1036
        auto t = inputs.at(0).type();
Khalique's avatar
Khalique committed
1037
1038
1039
1040
        std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0);
        return {t, scalar_bcast.lens(), strides};
    }

Paul's avatar
Paul committed
1041
    argument compute(shape output_shape, std::vector<argument> args) const
Khalique's avatar
Khalique committed
1042
1043
1044
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
Paul's avatar
Paul committed
1045
    int output_alias(const std::vector<shape>&) const { return 0; }
Khalique's avatar
Khalique committed
1046
1047
};

1048
struct binary
Scott Thornton's avatar
Scott Thornton committed
1049
{
1050
1051
    shape compute_shape(std::vector<shape> inputs) const
    {
1052
        check_shapes{inputs}.has(2).same_type().same_dims();
Scott Thornton's avatar
Scott Thornton committed
1053
        auto t    = inputs.at(0).type();
1054
1055
        auto lens = inputs.at(0).lens();
        return {t, lens};
1056
    }
Scott Thornton's avatar
Scott Thornton committed
1057
1058
};

1059
1060
1061
1062
1063
1064
/// Operator named "add"; shape computation is inherited from `binary`.
struct add : binary
{
    std::string name() const
    {
        return "add";
    }
};

struct sub : binary
Scott Thornton's avatar
Scott Thornton committed
1065
1066
1067
1068
{
    std::string name() const { return "sub"; }
};

1069
struct mul : binary
Scott Thornton's avatar
Scott Thornton committed
1070
1071
1072
1073
{
    std::string name() const { return "mul"; }
};

1074
struct div : binary
Scott Thornton's avatar
Scott Thornton committed
1075
1076
1077
1078
{
    std::string name() const { return "div"; }
};

Khalique's avatar
Khalique committed
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
/// Operator named "max"; shape computation is inherited from `binary`.
struct max : binary
{
    std::string name() const
    {
        return "max";
    }
};

/// Operator named "min"; shape computation is inherited from `binary`.
struct min : binary
{
    std::string name() const
    {
        return "min";
    }
};

Paul's avatar
Paul committed
1089
1090
1091
1092
struct load
{
    shape s;
    std::size_t offset = 0;
Paul's avatar
Paul committed
1093
1094
1095
1096

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1097
        return pack(f(self.s, "shape"), f(self.offset, "offset"));
Paul's avatar
Paul committed
1098
1099
    }

Paul's avatar
Paul committed
1100
1101
1102
1103
1104
1105
    std::string name() const { return "load"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs}.has(1);
        return s;
    }
Paul's avatar
Paul committed
1106
    argument compute(const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
1107
1108
1109
    {
        return {s, args[0].data() + offset};
    }
Paul's avatar
Paul committed
1110
    int output_alias(const std::vector<shape>&) const { return 0; }
Paul's avatar
Paul committed
1111
1112
};

Paul's avatar
Paul committed
1113
struct outline
Scott Thornton's avatar
Scott Thornton committed
1114
{
Paul's avatar
Paul committed
1115
    shape s;
Paul's avatar
Paul committed
1116
1117
1118
1119

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
1120
        return pack(f(self.s, "shape"));
Paul's avatar
Paul committed
1121
1122
    }

Paul's avatar
Paul committed
1123
    std::string name() const { return "outline"; }
Paul's avatar
Paul committed
1124
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
1125
    {
Paul's avatar
Paul committed
1126
        check_shapes{inputs, *this}.has(0);
Paul's avatar
Paul committed
1127
1128
        return s;
    }
Paul's avatar
Paul committed
1129
    argument compute(const shape&, const std::vector<argument>&) const { return {s, nullptr}; }
Scott Thornton's avatar
Scott Thornton committed
1130
1131
};

Shucai Xiao's avatar
Shucai Xiao committed
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
struct rnn
{

    enum rnn_direction_t
    {
        forward,
        reverse,
        bidirectional,
    };

Shucai Xiao's avatar
Shucai Xiao committed
1142
    std::size_t hidden_size = 1;
1143
    std::vector<operation> actv_funcs{tanh{}};
Shucai Xiao's avatar
Shucai Xiao committed
1144
    rnn_direction_t direction = forward;
Shucai Xiao's avatar
Shucai Xiao committed
1145
    float clip                = 0.0f;
Shucai Xiao's avatar
Shucai Xiao committed
1146
1147
1148
1149

    std::string name() const { return "rnn"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Shucai Xiao's avatar
Shucai Xiao committed
1150
        auto in_dims     = inputs[0].lens();
Shucai Xiao's avatar
Shucai Xiao committed
1151
1152
        auto hidden_dims = inputs[2].lens();
        if(hidden_size != hidden_dims[2])
Shucai Xiao's avatar
Shucai Xiao committed
1153
1154
1155
1156
1157
        {
            MIGRAPHX_THROW("RNN: hidden size mismatch in attribute and input");
        }

        std::size_t num_directions = 1;
Shucai Xiao's avatar
Shucai Xiao committed
1158
        if(direction == bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1159
1160
1161
1162
        {
            num_directions = 2;
        }

Shucai Xiao's avatar
Shucai Xiao committed
1163
        if(num_directions != hidden_dims[0])
Shucai Xiao's avatar
Shucai Xiao committed
1164
        {
1165
            MIGRAPHX_THROW("RNN: num_direction mismatch in attribute and input");
Shucai Xiao's avatar
Shucai Xiao committed
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
        }

        std::vector<std::size_t> out_dims(in_dims);
        out_dims.insert(out_dims.begin() + 1, num_directions);
        out_dims.back() = hidden_size;

        return {inputs[0].type(), out_dims};
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
/// GRU operator: attributes and output-shape logic.
struct gru
{
    enum gru_direction_t
    {
        forward,
        reverse,
        bidirectional,
    };

    /// Must equal dimension 2 of the third input.
    std::size_t hidden_size = 1;
    std::vector<operation> actv_funcs{sigmoid{}, tanh{}};
    gru_direction_t direction = forward;
    float clip                = 0.0f;
    int linear_before_reset   = 0;

    std::string name() const { return "gru"; }

    /// Output lengths are the first input's lengths with the number of
    /// directions inserted at position 1 and the last length replaced by
    /// `hidden_size`.
    /// Throws if fewer than three inputs are given, if the first input has no
    /// dimensions, if the third input has fewer than three dimensions, or if
    /// `hidden_size` / the direction count disagree with the third input.
    shape compute_shape(std::vector<shape> inputs) const
    {
        // inputs[0] and inputs[2] are indexed below.
        if(inputs.size() < 3)
        {
            MIGRAPHX_THROW("GRU: expected at least 3 inputs");
        }

        auto in_dims     = inputs[0].lens();
        auto hidden_dims = inputs[2].lens();
        // hidden_dims[2] is read and a length is inserted after in_dims[0].
        if(in_dims.empty() || hidden_dims.size() < 3)
        {
            MIGRAPHX_THROW("GRU: input dimensions are too small");
        }

        if(hidden_size != hidden_dims[2])
        {
            MIGRAPHX_THROW("GRU: hidden size mismatch in attribute and input");
        }

        std::size_t num_directions = 1;
        if(direction == bidirectional)
        {
            num_directions = 2;
        }

        if(num_directions != hidden_dims[0])
        {
            MIGRAPHX_THROW("GRU: num_direction does not match the direction attribute");
        }

        std::vector<std::size_t> out_dims(in_dims);
        out_dims.insert(out_dims.begin() + 1, num_directions);
        out_dims.back() = hidden_size;

        return {inputs[0].type(), out_dims};
    }
};

Shucai Xiao's avatar
Shucai Xiao committed
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
/// Operator named "rnn_last_output": its output shape is the input shape with
/// the first dimension removed.
struct rnn_last_output
{
    std::string name() const { return "rnn_last_output"; }

    /// Throws if there is not exactly one input or the input has no dimensions.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto dims = inputs[0].lens();

        // erase(begin()) on an empty vector is undefined behaviour
        if(dims.empty())
            MIGRAPHX_THROW("rnn_last_output: input must have at least one dimension");

        // remove the first dimension; the remaining ones form the output shape
        dims.erase(dims.begin());
        return {inputs[0].type(), dims};
    }
};

1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
/// Operator named "gru_last_output": its output shape is the input shape with
/// the first dimension removed.
struct gru_last_output
{
    std::string name() const { return "gru_last_output"; }

    /// Throws if there is not exactly one input or the input has no dimensions.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto dims = inputs[0].lens();

        // erase(begin()) on an empty vector is undefined behaviour
        if(dims.empty())
            MIGRAPHX_THROW("gru_last_output: input must have at least one dimension");

        // remove the first dimension; the remaining ones form the output shape
        dims.erase(dims.begin());
        return {inputs[0].type(), dims};
    }
};

1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
/// Operator named "undefined": takes no inputs and produces a default shape
/// and an argument with a null data pointer.
struct undefined
{
    std::string name() const
    {
        return "undefined";
    }

    /// Requires zero inputs and returns a default-constructed shape.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(0);
        return {};
    }

    /// Returns an argument built from an empty initializer and `nullptr`.
    argument compute(const shape&, const std::vector<argument>&) const
    {
        return {{}, nullptr};
    }
};

1260
} // namespace op
Paul's avatar
Paul committed
1261
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
1262
} // namespace migraphx
Paul's avatar
Paul committed
1263
1264

#endif