operators.hpp 26.5 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPH_GUARD_OPERATORS_HPP
#define MIGRAPH_GUARD_OPERATORS_HPP
Paul's avatar
Paul committed
3

4
#include <array>
#include <migraph/operation.hpp>
#include <migraph/check_shapes.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/streamutils.hpp>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
namespace migraph {
13
namespace op {
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
16
struct not_computable
{
Paul's avatar
Paul committed
17
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
18
19
20
    {
        MIGRAPH_THROW("not computable");
    }
Paul's avatar
Paul committed
21
22
};

23
24
struct batch_norm_inference
{
25
26
    float epsilon  = 1.0e-6f;
    float momentum = 0.9f;
27
28
29

    std::string name() const { return "batch_norm_inference"; }

30
31
32
33
34
35
36
37
    enum bn_infer_mode_t
    {
        per_activation,
        spatial,
    };

    bn_infer_mode_t bn_mode = spatial;

Paul's avatar
Paul committed
38
39
40
41
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(
Paul's avatar
Paul committed
42
            f(self.epsilon, "epsilon"), f(self.momentum, "momentum"), f(self.bn_mode, "bn_mode"));
Paul's avatar
Paul committed
43
    }
44

45
46
47
48
49
50
51
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        return inputs.front();
    }
};

Paul's avatar
Paul committed
52
struct convolution
Paul's avatar
Paul committed
53
{
Paul's avatar
Paul committed
54
55
56
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Paul's avatar
Paul committed
57
58
59
60
61
62
63
    enum padding_mode_t
    {
        default_, // NOLINT
        same,
        valid
    };
    padding_mode_t padding_mode = default_;
Paul's avatar
Paul committed
64
65
66
67

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
68
69
70
71
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
                    f(self.padding_mode, "padding_mode"));
Paul's avatar
Paul committed
72
73
    }

Paul's avatar
Paul committed
74
    std::string name() const { return "convolution"; }
Paul's avatar
Paul committed
75
76
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
77
        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
Paul's avatar
Paul committed
78

Paul's avatar
Paul committed
79
        const shape& input   = inputs.at(0);
Paul's avatar
Paul committed
80
        const shape& weights = inputs.at(1);
Paul's avatar
Paul committed
81
        auto t               = input.type();
Paul's avatar
Paul committed
82
83
        if(padding_mode == default_)
        {
Paul's avatar
Paul committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
            return {t,
                    {
                        input.lens()[0],
                        weights.lens()[0],
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
                             2 * padding[0]) /
                                    stride[0] +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
                             2 * padding[1]) /
                                    stride[1] +
                                1)),
                    }};
Paul's avatar
Paul committed
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
                     weights.lens()[0],
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {
                t,
                {input.lens()[0],
                 weights.lens()[0],
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
        }
        else
        {
Paul's avatar
Paul committed
125
            MIGRAPH_THROW("Invalid padding mode");
Paul's avatar
Paul committed
126
        }
Paul's avatar
Paul committed
127
128
129
    }
};

Scott Thornton's avatar
Scott Thornton committed
130
131
/// Unfolds a single-image input {1, C, H, W} into a column matrix for a
/// convolution with weights {K, C, kH, kW}.
struct im2col
{
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
    enum padding_mode_t
    {
        default_, // NOLINT
        same,
        valid
    };
    // NOTE(review): padding_mode is reflected but not consulted by
    // compute_shape; the output always uses the explicit `padding` values.
    padding_mode_t padding_mode = default_;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.dilation, "dilation"),
                    f(self.padding_mode, "padding_mode"));
    }

    std::string name() const { return "im2col"; }

    /// Output shape {H_out * W_out, C * kH * kW}, where per spatial axis
    /// out = max(1, floor((in + 2*padding - (1 + dilation*(k-1))) / stride) + 1).
    /// @throws if the inputs are not two 4-D tensors, the batch size is not 1,
    ///         or a stride is zero.
    shape compute_shape(std::vector<shape> inputs) const
    {
        // Validate before indexing into the inputs or their lengths.
        check_shapes{inputs, *this}.has(2).only_dims(4);

        const shape& input   = inputs[0];
        const shape& weights = inputs[1];
        if(input.lens()[0] != 1)
            MIGRAPH_THROW("im2col only support batch_size 1");
        if(stride[0] == 0 || stride[1] == 0)
            MIGRAPH_THROW("im2col: stride must be nonzero");

        const auto input_channels = weights.lens()[1];
        const auto kernel_height  = weights.lens()[2];
        const auto kernel_width   = weights.lens()[3];

        // Signed arithmetic so an oversized kernel clamps to 1 instead of wrapping.
        auto output_len = [&](std::size_t i, std::size_t kernel) {
            const auto in    = static_cast<std::ptrdiff_t>(input.lens()[i + 2]);
            const auto pad   = static_cast<std::ptrdiff_t>(padding[i]);
            const auto s     = static_cast<std::ptrdiff_t>(stride[i]);
            const auto k_eff = static_cast<std::ptrdiff_t>(1 + dilation[i] * (kernel - 1));
            return static_cast<std::size_t>(
                std::max<std::ptrdiff_t>(1, (in + 2 * pad - k_eff) / s + 1));
        };

        const auto output_height = output_len(0, kernel_height);
        const auto output_width  = output_len(1, kernel_width);
        const auto channels_col  = kernel_height * kernel_width * input_channels;
        return {input.type(), {output_height * output_width, channels_col}};
    }
};

Paul's avatar
Paul committed
180
struct pooling
Paul's avatar
Paul committed
181
{
Paul's avatar
Paul committed
182
    std::string mode                   = "average";
Paul's avatar
Paul committed
183
184
185
    std::array<std::size_t, 2> padding = {{0, 0}};
    std::array<std::size_t, 2> stride  = {{1, 1}};
    std::array<std::size_t, 2> lengths = {{1, 1}};
Paul's avatar
Paul committed
186
187
188
189

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
190
191
192
193
        return pack(f(self.mode, "mode"),
                    f(self.padding, "padding"),
                    f(self.stride, "stride"),
                    f(self.lengths, "lengths"));
Paul's avatar
Paul committed
194
195
    }

Paul's avatar
Paul committed
196
    std::string name() const { return "pooling"; }
Scott Thornton's avatar
Scott Thornton committed
197

Paul's avatar
Paul committed
198
199
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
200
        check_shapes{inputs, *this}.has(1).only_dims(4);
Paul's avatar
Paul committed
201

Paul's avatar
Paul committed
202
        const shape& input = inputs.at(0);
Paul's avatar
Paul committed
203
        auto t             = input.type();
Paul's avatar
Paul committed
204

Paul's avatar
Paul committed
205
206
        assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
        assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
Paul's avatar
Paul committed
207

Scott Thornton's avatar
Scott Thornton committed
208
209
210
211
212
213
        return {t,
                {
                    input.lens()[0],
                    input.lens()[1],
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
214
                        std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
Paul's avatar
Paul committed
215
                                                  static_cast<float>(stride[0]))) +
Scott Thornton's avatar
Scott Thornton committed
216
217
218
                            1)),
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
219
                        std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
Paul's avatar
Paul committed
220
                                                  static_cast<float>(stride[1]))) +
Scott Thornton's avatar
Scott Thornton committed
221
222
                            1)),
                }};
Paul's avatar
Paul committed
223
224
225
    }
};

Khalique's avatar
Khalique committed
226
227
228
229
230
231
232
233
234
235
236
237
238
/// Leaky ReLU operator; output shape equals the single input's shape.
struct leaky_relu
{
    std::string name() const { return "leaky_relu"; }
    // Presumably the negative-side slope, as in ONNX LeakyRelu. Given a default
    // (0.01, the ONNX default) so a default-constructed op has a determinate value.
    float alpha = 0.01f;
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        return inputs.front();
    }
    /// Prints "leaky_relu:<alpha>".
    friend std::ostream& operator<<(std::ostream& os, const leaky_relu& op)
    {
        os << op.name() << ":" << op.alpha;
        return os;
    }
};

242
243
244
struct transpose
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
245
246
247
248

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
249
        return pack(f(self.dims, "dims"));
Paul's avatar
Paul committed
250
251
    }

252
253
254
    std::string name() const { return "transpose"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
255
        check_shapes{inputs, *this}.has(1);
256
        auto input         = inputs.at(0);
257
        auto input_lens    = input.lens();
258
259
        auto input_strides = input.strides();
        auto t             = input.type();
Paul's avatar
Paul committed
260
261
        if(dims.size() != input_lens.size())
        {
Paul's avatar
Paul committed
262
            MIGRAPH_THROW("Permutation has wrong number of axes");
263
264
265
        }
        std::vector<int64_t> axes(dims.size());
        std::iota(axes.begin(), axes.end(), 0);
Paul's avatar
Paul committed
266
267
        if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
        {
Paul's avatar
Paul committed
268
            MIGRAPH_THROW("Invalid permutation");
269
        }
270
271
        std::vector<size_t> output_lens(input_lens.size());
        std::vector<size_t> output_strides(input_lens.size());
Paul's avatar
Paul committed
272
273
274
        for(int i = 0; i < output_lens.size(); i++)
        {
            output_lens[i]    = input_lens[dims[i]];
275
276
            output_strides[i] = input_strides[dims[i]];
        }
277
        return {t, output_lens, output_strides};
278
    }
Paul's avatar
Paul committed
279
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
280
    {
Paul's avatar
Paul committed
281
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
282
    }
Paul's avatar
Paul committed
283
    int output_alias(const std::vector<shape>&) const { return 0; }
284
285
};

Paul's avatar
Paul committed
286
struct contiguous
287
288
289
290
{
    std::string name() const { return "contiguous"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
291
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
292
293
        auto lens = inputs.at(0).lens();
        auto t    = inputs.at(0).type();
294
295
296
297
        return {t, lens};
    }
};

298
299
300
301
struct concat
{
    std::size_t axis = 0;
    std::string name() const { return "concat"; }
302
303
304
305
306
307
308
309
310
311
312
313
314
    std::vector<std::size_t> compute_offsets(const shape& output_shape,
                                             const std::vector<argument> args) const
    {
        std::vector<std::size_t> offsets;
        std::vector<std::size_t> offset(args[0].get_shape().lens().size(), 0);
        offset[axis] = 0;
        for(const auto& arg : args)
        {
            offsets.push_back(output_shape.index(offset));
            offset[axis] += arg.get_shape().lens()[axis];
        }
        return offsets;
    }
315
316
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
317
        if(inputs.empty())
318
319
320
321
322
        {
            MIGRAPH_THROW("Number of input tensors should exceed 0");
        }

        const auto& first_shape_lens = inputs.front().lens();
Scott Thornton's avatar
Scott Thornton committed
323
324
325
326
327
328
329
330
331
332
        const auto& type             = inputs.front().type();
        for(std::size_t l = 0; l < first_shape_lens.size(); l++)
        {
            if(l != axis)
            {
                if(!std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
                       return s.lens()[l] == first_shape_lens[l];
                   }))
                {
                    MIGRAPH_THROW("Non-axis dimensions should match");
333
334
335
336
                }
            }
        }
        std::size_t new_dim_axis = 0;
Scott Thornton's avatar
Scott Thornton committed
337
        for(const auto& input : inputs)
338
339
340
341
342
343
344
345
346
        {
            const auto& lens = input.lens();
            new_dim_axis += lens[axis];
        }
        std::vector<std::size_t> new_lens;
        std::copy(first_shape_lens.begin(), first_shape_lens.end(), std::back_inserter(new_lens));
        new_lens[axis] = new_dim_axis;
        return {type, new_lens};
    }
Paul's avatar
Paul committed
347
    int output_alias(const std::vector<shape>&) const { return 0; }
348
349
};

350
351
352
353
354
struct slice
{
    std::vector<int64_t> axes;
    std::vector<int64_t> starts;
    std::vector<int64_t> ends;
Paul's avatar
Paul committed
355
356
357
358

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
359
        return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
Paul's avatar
Paul committed
360
361
    }

362
    std::string name() const { return "slice"; }
Scott Thornton's avatar
Scott Thornton committed
363
364

    auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
365
    {
Scott Thornton's avatar
Scott Thornton committed
366
        int64_t r = std::min(index, static_cast<int64_t>(lens[axis]));
Scott Thornton's avatar
Scott Thornton committed
367
368
        if(r < 0)
            r += lens[axis];
Scott Thornton's avatar
Scott Thornton committed
369
        return std::size_t(r);
Scott Thornton's avatar
Scott Thornton committed
370
371
372
373
374
375
376
    }

    auto compute_offset(const shape& s) const
    {
        const std::vector<std::size_t>& lens    = s.lens();
        const std::vector<std::size_t>& strides = s.strides();
        auto offset                             = 0;
Scott Thornton's avatar
Scott Thornton committed
377
        if(!axes.empty())
Scott Thornton's avatar
Scott Thornton committed
378
        {
Scott Thornton's avatar
Scott Thornton committed
379
380
381
382
383
            for(std::size_t i = 0; i < axes.size(); i++)
            {
                auto axis = axes[i];
                offset += fix_index(lens, axis, starts[i]) * strides[axis];
            }
384
        }
Scott Thornton's avatar
Scott Thornton committed
385
386
        else
        {
Scott Thornton's avatar
Scott Thornton committed
387
388
389
390
            for(std::size_t axis = 0; axis < lens.size(); axis++)
            {
                offset += fix_index(lens, axis, starts[axis]) * strides[axis];
            }
391
        }
Scott Thornton's avatar
Scott Thornton committed
392
393
394
395
396
        return offset;
    }

    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
397
398
399
400
        auto input_shape        = inputs[0];
        auto t                  = input_shape.type();
        const auto& old_lens    = input_shape.lens();
        const auto& old_strides = input_shape.strides();
Scott Thornton's avatar
Scott Thornton committed
401
402
403
404
405
406
407
408
409
410
        // std::vector<int64_t> t_axes(old_lens.size());
        // if(axes.size() == 0)
        // {
        //     std::iota(t_axes.begin(), t_axes.end(), 0);
        // }
        // else
        // {
        //     std::copy(axes.begin(), axes.end(), t_axes.begin());
        // }
        if(starts.size() != axes.size() || axes.size() != ends.size())
Scott Thornton's avatar
Scott Thornton committed
411
        {
412
413
            MIGRAPH_THROW("inconsistent sizes");
        }
Scott Thornton's avatar
Scott Thornton committed
414
415
        std::vector<std::size_t> new_lens = old_lens;
        for(std::size_t i = 0; i < axes.size(); i++)
Scott Thornton's avatar
Scott Thornton committed
416
        {
Scott Thornton's avatar
Scott Thornton committed
417
418
419
            auto axis = axes[i];
            new_lens[axis] =
                fix_index(old_lens, axis, ends[i]) - fix_index(old_lens, axis, starts[i]);
420
421
422
423
424
        }
        return shape{t, new_lens, old_strides};
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
Scott Thornton's avatar
Scott Thornton committed
425
426
427
        auto input  = args[0];
        auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
        return {std::move(output_shape), [=] { return input.data() + offset; }};
428
    }
Paul's avatar
Paul committed
429
    int output_alias(const std::vector<shape>&) const { return 0; }
430
431
432
433
434
};

struct squeeze
{
    std::vector<int64_t> axes;
Paul's avatar
Paul committed
435
436
437
438

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
439
        return pack(f(self.axes, "axes"));
Paul's avatar
Paul committed
440
441
    }

442
443
444
445
    std::string name() const { return "squeeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        auto input_shape = inputs[0];
Scott Thornton's avatar
Scott Thornton committed
446
447
        auto type        = input_shape.type();
        auto old_lens    = input_shape.lens();
wsttiger's avatar
wsttiger committed
448
449
        if(std::any_of(
               axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; }))
Scott Thornton's avatar
Scott Thornton committed
450
        {
wsttiger's avatar
wsttiger committed
451
            MIGRAPH_THROW("squeeze axis dimension should be equal to 1");
452
453
        }
        std::vector<std::size_t> new_lens;
Scott Thornton's avatar
Scott Thornton committed
454
        if(axes.empty())
Scott Thornton's avatar
Scott Thornton committed
455
        {
wsttiger's avatar
wsttiger committed
456
457
458
459
            std::copy_if(old_lens.begin(),
                         old_lens.end(),
                         std::back_inserter(new_lens),
                         [](auto len) { return len != 1; });
460
        }
Scott Thornton's avatar
Scott Thornton committed
461
462
463
464
465
466
        else
        {
            for(std::size_t i = 0; i < old_lens.size(); i++)
            {
                if(std::find(axes.begin(), axes.end(), i) == axes.end())
                {
467
468
469
470
471
472
473
474
475
                    new_lens.push_back(old_lens[i]);
                }
            }
        }
        return shape{type, new_lens};
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.front().data)};
Scott Thornton's avatar
Scott Thornton committed
476
    }
Paul's avatar
Paul committed
477
    int output_alias(const std::vector<shape>&) const { return 0; }
478
479
480
481
482
};

struct unsqueeze
{
    std::vector<int64_t> axes;
Paul's avatar
Paul committed
483
484
485
486

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
487
        return pack(f(self.axes, "axes"));
Paul's avatar
Paul committed
488
489
    }

490
491
492
    std::string name() const { return "unsqueeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
493
494
495
        auto input_shape     = inputs[0];
        auto type            = input_shape.type();
        auto old_lens        = input_shape.lens();
496
497
498
        std::size_t new_size = old_lens.size() + axes.size();
        std::vector<std::size_t> new_lens(new_size);
        std::size_t p = 0;
Scott Thornton's avatar
Scott Thornton committed
499
500
501
502
        for(std::size_t i = 0; i < new_size; i++)
        {
            if(std::find(axes.begin(), axes.end(), i) != axes.end())
            {
503
                new_lens[i] = 1;
Scott Thornton's avatar
Scott Thornton committed
504
505
506
            }
            else
            {
507
508
509
510
511
512
513
514
515
                new_lens[i] = old_lens[p++];
            }
        }
        return shape{type, new_lens};
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
Paul's avatar
Paul committed
516
    int output_alias(const std::vector<shape>&) const { return 0; }
517
518
};

Paul's avatar
Paul committed
519
520
521
struct reshape
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
522
523
524
525

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
526
        return pack(f(self.dims, "dims"));
Paul's avatar
Paul committed
527
528
    }

Paul's avatar
Paul committed
529
    std::string name() const { return "reshape"; }
Paul's avatar
Paul committed
530
531
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
532
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
533
534
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(dims.begin(), dims.end());
535
536
537
        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
        if(n_neg_dims > 1)
            MIGRAPH_THROW("Dimensions for reshape can only have one -1 dim");
Paul's avatar
Paul committed
538
        for(std::size_t i = 0; i < dims.size(); i++)
Paul's avatar
Paul committed
539
540
541
542
        {
            if(dims[i] == 0)
                rdims[i] = idims[i];
        }
543
544
545
546
547
548
549
550
551
552
553
        if(n_neg_dims > 0)
        {
            size_t missing_dim =
                -inputs.front().elements() /
                std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
            for(std::size_t i = 0; i < rdims.size(); i++)
            {
                if(dims[i] == -1)
                    rdims[i] = missing_dim;
            }
        }
Paul's avatar
Paul committed
554
555
556
        if(dims.back() == -1)
        {
            rdims.pop_back();
Paul's avatar
Paul committed
557
            std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
Paul's avatar
Paul committed
558
        }
Scott Thornton's avatar
Scott Thornton committed
559
        shape s{inputs.front().type(), rdims};
Paul's avatar
Paul committed
560
        if(s.elements() != inputs.front().elements())
Paul's avatar
Paul committed
561
            MIGRAPH_THROW("Wrong number of elements for reshape");
Scott Thornton's avatar
Scott Thornton committed
562
        return s;
Paul's avatar
Paul committed
563
    }
Paul's avatar
Paul committed
564
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
565
    {
Paul's avatar
Paul committed
566
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
567
    }
Paul's avatar
Paul committed
568
    int output_alias(const std::vector<shape>&) const { return 0; }
Paul's avatar
Paul committed
569
570
};

Shucai Xiao's avatar
Shucai Xiao committed
571
struct dot
572
{
Paul's avatar
Paul committed
573
    float alpha = 1.0;
Paul's avatar
Paul committed
574
    float beta  = 0.0;
Paul's avatar
Paul committed
575
576
577
578

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
579
        return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
Paul's avatar
Paul committed
580
581
    }

Shucai Xiao's avatar
Shucai Xiao committed
582
    std::string name() const { return "dot"; }
583
584
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
585
        check_shapes{inputs, *this}.has(2).same_type();
586
587
        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);
Scott Thornton's avatar
Scott Thornton committed
588
        auto t         = a.type();
589

590
        if(a.lens()[1] != b.lens()[0])
Paul's avatar
Paul committed
591
592
            MIGRAPH_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + "} x {" +
                          to_string_range(b.lens()) + "}");
Scott Thornton's avatar
Scott Thornton committed
593
        return {t, {a.lens()[0], b.lens()[1]}};
594
595
596
    }
};

597
struct unary
Scott Thornton's avatar
Scott Thornton committed
598
{
599
600
    shape compute_shape(std::vector<shape> inputs) const
    {
601
602
        check_shapes{inputs}.has(1);
        return inputs.at(0);
603
    }
Scott Thornton's avatar
Scott Thornton committed
604
605
};

606
struct identity
607
{
608
    std::string name() const { return "identity"; }
Scott Thornton's avatar
Scott Thornton committed
609
    shape compute_shape(std::vector<shape> inputs) const { return inputs.at(0); }
610
611
612
613
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
614
615
616
};

struct abs : unary
Scott Thornton's avatar
Scott Thornton committed
617
{
618
    std::string name() const { return "abs"; }
Scott Thornton's avatar
Scott Thornton committed
619
620
};

621
struct exp : unary
Scott Thornton's avatar
Scott Thornton committed
622
{
623
    std::string name() const { return "exp"; }
Scott Thornton's avatar
Scott Thornton committed
624
625
};

626
struct sin : unary
Scott Thornton's avatar
Scott Thornton committed
627
{
628
    std::string name() const { return "sin"; }
Scott Thornton's avatar
Scott Thornton committed
629
630
};

631
struct cos : unary
Scott Thornton's avatar
Scott Thornton committed
632
{
633
    std::string name() const { return "cos"; }
Scott Thornton's avatar
Scott Thornton committed
634
635
};

636
struct tan : unary
Scott Thornton's avatar
Scott Thornton committed
637
{
638
    std::string name() const { return "tan"; }
Scott Thornton's avatar
Scott Thornton committed
639
640
};

641
struct asin : unary
Scott Thornton's avatar
Scott Thornton committed
642
{
643
    std::string name() const { return "asin"; }
Scott Thornton's avatar
Scott Thornton committed
644
645
};

646
struct acos : unary
Scott Thornton's avatar
Scott Thornton committed
647
{
648
    std::string name() const { return "acos"; }
Scott Thornton's avatar
Scott Thornton committed
649
650
};

651
struct atan : unary
Scott Thornton's avatar
Scott Thornton committed
652
{
653
    std::string name() const { return "atan"; }
Scott Thornton's avatar
Scott Thornton committed
654
655
};

656
struct tanh : unary
Scott Thornton's avatar
Scott Thornton committed
657
{
658
    std::string name() const { return "tanh"; }
Scott Thornton's avatar
Scott Thornton committed
659
660
};

661
struct sigmoid : unary
Scott Thornton's avatar
Scott Thornton committed
662
{
663
    std::string name() const { return "sigmoid"; }
Scott Thornton's avatar
Scott Thornton committed
664
665
};

666
struct neg : unary
Scott Thornton's avatar
Scott Thornton committed
667
{
668
    std::string name() const { return "neg"; }
Scott Thornton's avatar
Scott Thornton committed
669
670
};

Khalique's avatar
Khalique committed
671
672
673
674
675
struct relu : unary
{
    std::string name() const { return "relu"; }
};

Paul's avatar
Paul committed
676
677
678
679
680
681
682
683
684
685
/// Softmax operator over a single 4-D input; output shape equals the input shape.
struct softmax
{
    std::string name() const { return "softmax"; }
    /// @throws if there is not exactly one 4-D input.
    shape compute_shape(std::vector<shape> inputs) const
    {
        // Pass *this (as the other ops do) so errors name the operator.
        check_shapes{inputs, *this}.has(1).only_dims(4);
        return inputs.at(0);
    }
};

686
struct flatten
Scott Thornton's avatar
Scott Thornton committed
687
{
Paul's avatar
Paul committed
688
    uint64_t axis = 0;
Paul's avatar
Paul committed
689
690
691
692

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
693
        return pack(f(self.axis, "axis"));
Paul's avatar
Paul committed
694
695
    }

Scott Thornton's avatar
Scott Thornton committed
696
    std::string name() const { return "flatten"; }
Paul's avatar
Paul committed
697
698
699
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
Paul's avatar
Paul committed
700
701
        auto&& lens = inputs.front().lens();

Paul's avatar
Paul committed
702
        if(axis > lens.size())
Paul's avatar
Paul committed
703
        {
Paul's avatar
Paul committed
704
            MIGRAPH_THROW("axis for flatten must be less than tensor rank");
Paul's avatar
Paul committed
705
        }
Paul's avatar
Paul committed
706
707
708
709
        auto x =
            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
        auto y =
            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
710
        return {inputs.at(0).type(), {x, y}};
Paul's avatar
Paul committed
711
712
713
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
Paul's avatar
Paul committed
714
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
715
    }
Paul's avatar
Paul committed
716
    int output_alias(const std::vector<shape>&) const { return 0; }
Scott Thornton's avatar
Scott Thornton committed
717
};
718
719
720
struct broadcast
{
    uint64_t axis = 0;
Paul's avatar
Paul committed
721
722
723
724

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
725
        return pack(f(self.axis, "axis"));
Paul's avatar
Paul committed
726
727
    }

Scott Thornton's avatar
Scott Thornton committed
728
    shape broadcast_shape;
729
730
731
    std::string name() const { return "broadcast"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
732
733
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);
Paul's avatar
Paul committed
734

Scott Thornton's avatar
Scott Thornton committed
735
        std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0);
736

Scott Thornton's avatar
Scott Thornton committed
737
738
739
        if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) {
               return x == 1;
           }))
740
        {
Scott Thornton's avatar
Scott Thornton committed
741
            if(axis != 0)
Paul's avatar
Paul committed
742
                MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
Scott Thornton's avatar
Scott Thornton committed
743
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
744
745
746
        }
        else
        {
Scott Thornton's avatar
Scott Thornton committed
747
            assert(broadcast_shape.lens().size() - axis >= input.lens().size());
Scott Thornton's avatar
Scott Thornton committed
748
749
            if(!std::equal(
                   input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
Paul's avatar
Paul committed
750
                MIGRAPH_THROW("when broadcasting success sizes must match");
Paul's avatar
Paul committed
751
            std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
Scott Thornton's avatar
Scott Thornton committed
752
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
753
754
        }
    }
Paul's avatar
Paul committed
755
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
756
    {
Scott Thornton's avatar
Scott Thornton committed
757
        return {std::move(output_shape), std::move(args.at(0).data)};
Scott Thornton's avatar
Scott Thornton committed
758
    }
Paul's avatar
Paul committed
759
    int output_alias(const std::vector<shape>&) const { return 0; }
760
761
};

Scott Thornton's avatar
Scott Thornton committed
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
struct multibroadcast
{
    /// Dims of the broadcast result.
    std::vector<std::size_t> output_lens;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.output_lens, "output_lens"));
    }

    std::string name() const { return "multibroadcast"; }

    /// Broadcast input 0 to output_lens, aligning dims from the right.
    ///
    /// Input dim i corresponds to output dim i + (output rank - input rank).
    /// - A dim equal to its output dim keeps the input's stride.
    /// - A dim of 1 is broadcast with stride 0.
    /// - Leading output dims with no input counterpart also get stride 0.
    ///
    /// Throws via MIGRAPH_THROW for an empty input rank, an input rank larger than
    /// the output rank, or a non-1 dim that differs from its output dim.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);

        if(input.lens().empty())
            MIGRAPH_THROW("inputs dimensions should be > 0");

        if(input.lens().size() > output_lens.size())
            MIGRAPH_THROW("inputs dimensions should <= output size");

        // Iterate over the input's dims, not the output's: every index is then in
        // range for both vectors (output index = i + offset), however large the
        // rank difference is.
        auto offset = output_lens.size() - input.lens().size();
        std::vector<size_t> bcast_strides(output_lens.size(), 0);
        for(std::size_t i = 0; i < input.lens().size(); i++)
        {
            auto in_len  = input.lens()[i];
            auto out_len = output_lens[i + offset];
            if(in_len == out_len)
                bcast_strides[i + offset] = input.strides()[i];
            else if(in_len != 1)
                MIGRAPH_THROW("multibroadcast: input dimension must be 1 or equal the output "
                              "dimension");
            // else: in_len == 1 is broadcast along this dim, so its stride stays 0
        }
        return {t, output_lens, bcast_strides};
    }

    /// Re-labels the input's data with the broadcast shape; no element data is copied here.
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }

    /// Returns 0 — presumably the index of the input whose storage the output aliases.
    int output_alias(const std::vector<shape>&) const { return 0; }
};

Khalique's avatar
Khalique committed
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
struct scalar
{
    shape scalar_bcast;

    std::string name() const { return "scalar"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        assert(check_shapes{inputs}.has(1).only_dims(1).size() == 1);
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);
        std::vector<std::size_t> strides(scalar_bcast.lens().size(), 0);
        return {t, scalar_bcast.lens(), strides};
    }

    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.at(0).data)};
    }
Paul's avatar
Paul committed
828
    int output_alias(const std::vector<shape>&) const { return 0; }
Khalique's avatar
Khalique committed
829
830
};

831
struct binary
Scott Thornton's avatar
Scott Thornton committed
832
{
833
834
    shape compute_shape(std::vector<shape> inputs) const
    {
835
836
        check_shapes{inputs}.has(2).same_type().same_dims();
        return inputs.at(0);
837
    }
Scott Thornton's avatar
Scott Thornton committed
838
839
};

840
841
842
843
844
845
struct add : binary
{
    /// Operator name; shape inference is inherited from `binary`.
    std::string name() const
    {
        return "add";
    }
};

struct sub : binary
Scott Thornton's avatar
Scott Thornton committed
846
847
848
849
{
    std::string name() const { return "sub"; }
};

850
struct mul : binary
Scott Thornton's avatar
Scott Thornton committed
851
852
853
854
{
    std::string name() const { return "mul"; }
};

855
struct div : binary
Scott Thornton's avatar
Scott Thornton committed
856
857
858
859
{
    std::string name() const { return "div"; }
};

Paul's avatar
Paul committed
860
861
862
863
struct load
{
    shape s;
    std::size_t offset = 0;
Paul's avatar
Paul committed
864
865
866
867

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
868
        return pack(f(self.s, "shape"), f(self.offset, "offset"));
Paul's avatar
Paul committed
869
870
    }

Paul's avatar
Paul committed
871
872
873
874
875
876
877
878
879
880
    std::string name() const { return "load"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs}.has(1);
        return s;
    }
    argument compute(context&, const shape&, const std::vector<argument>& args) const
    {
        return {s, args[0].data() + offset};
    }
Paul's avatar
Paul committed
881
    int output_alias(const std::vector<shape>&) const { return 0; }
Paul's avatar
Paul committed
882
883
};

Paul's avatar
Paul committed
884
struct outline
Scott Thornton's avatar
Scott Thornton committed
885
{
Paul's avatar
Paul committed
886
    shape s;
Paul's avatar
Paul committed
887
888
889
890

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
Paul's avatar
Paul committed
891
        return pack(f(self.s, "shape"));
Paul's avatar
Paul committed
892
893
    }

Paul's avatar
Paul committed
894
    std::string name() const { return "outline"; }
Paul's avatar
Paul committed
895
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
896
    {
Paul's avatar
Paul committed
897
        check_shapes{inputs, *this}.has(0);
Paul's avatar
Paul committed
898
899
        return s;
    }
Paul's avatar
Paul committed
900
901
902
903
    argument compute(context&, const shape&, const std::vector<argument>&) const
    {
        return {s, nullptr};
    }
Scott Thornton's avatar
Scott Thornton committed
904
905
};

906
} // namespace op
Paul's avatar
Paul committed
907
} // namespace migraph
Paul's avatar
Paul committed
908
909

#endif