operators.hpp 20.9 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPH_GUARD_OPERATORS_HPP
#define MIGRAPH_GUARD_OPERATORS_HPP
Paul's avatar
Paul committed
3

4
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include <migraph/operation.hpp>
#include <migraph/check_shapes.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/streamutils.hpp>
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
namespace migraph {
Paul's avatar
Paul committed
13

Paul's avatar
Paul committed
14
15
struct not_computable
{
Paul's avatar
Paul committed
16
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
17
18
19
    {
        MIGRAPH_THROW("not computable");
    }
Paul's avatar
Paul committed
20
21
};

22
23
struct batch_norm_inference
{
24
25
    float epsilon  = 1.0e-6f;
    float momentum = 0.9f;
26
27
28

    std::string name() const { return "batch_norm_inference"; }

29
30
31
32
33
34
35
36
    enum bn_infer_mode_t
    {
        per_activation,
        spatial,
    };

    bn_infer_mode_t bn_mode = spatial;

37
38
    bool is_test = false;

39
40
41
42
43
44
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        return inputs.front();
    }

Paul's avatar
Paul committed
45
    argument compute(context&, const shape&, const std::vector<argument>&) const
46
47
48
49
50
    {
        MIGRAPH_THROW("not computable");
    }
};

Paul's avatar
Paul committed
51
struct convolution
Paul's avatar
Paul committed
52
{
Paul's avatar
Paul committed
53
54
55
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Paul's avatar
Paul committed
56
57
58
59
60
61
62
    enum padding_mode_t
    {
        default_, // NOLINT
        same,
        valid
    };
    padding_mode_t padding_mode = default_;
Paul's avatar
Paul committed
63
    std::string name() const { return "convolution"; }
Paul's avatar
Paul committed
64
65
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
66
        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
Paul's avatar
Paul committed
67

Paul's avatar
Paul committed
68
        const shape& input   = inputs.at(0);
Paul's avatar
Paul committed
69
        const shape& weights = inputs.at(1);
Paul's avatar
Paul committed
70
        auto t               = input.type();
Paul's avatar
Paul committed
71
72
        if(padding_mode == default_)
        {
Paul's avatar
Paul committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
            return {t,
                    {
                        input.lens()[0],
                        weights.lens()[0],
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
                             2 * padding[0]) /
                                    stride[0] +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
                             2 * padding[1]) /
                                    stride[1] +
                                1)),
                    }};
Paul's avatar
Paul committed
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
                     weights.lens()[0],
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {
                t,
                {input.lens()[0],
                 weights.lens()[0],
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
        }
        else
        {
Paul's avatar
Paul committed
114
            MIGRAPH_THROW("Invalid padding mode");
Paul's avatar
Paul committed
115
        }
Paul's avatar
Paul committed
116
    }
Paul's avatar
Paul committed
117

Paul's avatar
Paul committed
118
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
119
120
121
    {
        MIGRAPH_THROW("not computable");
    }
Paul's avatar
Paul committed
122

Paul's avatar
Paul committed
123
    friend std::ostream& operator<<(std::ostream& os, const convolution& op)
Paul's avatar
Paul committed
124
    {
Paul's avatar
Paul committed
125
126
127
128
129
        os << op.name() << "[";
        os << "padding={" << stream_range(op.padding) << "}, ";
        os << "stride={" << stream_range(op.stride) << "}, ";
        os << "dilation={" << stream_range(op.dilation) << "}";
        os << "]";
Paul's avatar
Paul committed
130
131
        return os;
    }
Paul's avatar
Paul committed
132
133
};

Scott Thornton's avatar
Scott Thornton committed
134
135
/// Shape inference for an im2col transform (compute() always throws).
///
/// Takes (input, weights), both 4-D; only batch size 1 is supported. The
/// output shape is {out_h * out_w, kernel_h * kernel_w * weights.lens()[1]}.
struct im2col
{
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
    enum padding_mode_t
    {
        default_, // NOLINT
        same,
        valid
    };

    std::string name() const { return "im2col"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        // Validate before indexing: inputs[1] and lens()[3] are read below.
        check_shapes{inputs, *this}.has(2).only_dims(4);

        const shape& input   = inputs[0];
        const shape& weights = inputs[1];
        auto batch_size      = input.lens()[0];
        auto input_channels  = weights.lens()[1];
        auto kernel_height   = weights.lens()[2];
        auto kernel_width    = weights.lens()[3];

        if(batch_size != 1)
            MIGRAPH_THROW("im2col only support batch_size 1");
        if(stride[0] == 0 || stride[1] == 0)
            MIGRAPH_THROW("im2col stride must be non-zero");

        // Signed arithmetic so a kernel larger than the padded input clamps to
        // 1 instead of wrapping around in std::size_t.
        auto sdim    = [](std::size_t x) { return static_cast<std::ptrdiff_t>(x); };
        auto out_dim = [&](std::size_t i, std::size_t kernel_len) {
            auto extent = 1 + sdim(dilation[i]) * (sdim(kernel_len) - 1);
            auto span   = sdim(input.lens()[i + 2]) + 2 * sdim(padding[i]) - extent;
            return static_cast<std::size_t>(
                std::max<std::ptrdiff_t>(1, span / sdim(stride[i]) + 1));
        };

        auto output_height = out_dim(0, kernel_height);
        auto output_width  = out_dim(1, kernel_width);
        auto channels_col  = kernel_height * kernel_width * input_channels;
        return {input.type(), {output_height * output_width, channels_col}};
    }

    argument compute(context&, const shape&, const std::vector<argument>&) const
    {
        MIGRAPH_THROW("not computable");
    }
};

Paul's avatar
Paul committed
179
struct pooling
Paul's avatar
Paul committed
180
{
Paul's avatar
Paul committed
181
    std::string mode                   = "average";
Paul's avatar
Paul committed
182
183
184
    std::array<std::size_t, 2> padding = {{0, 0}};
    std::array<std::size_t, 2> stride  = {{1, 1}};
    std::array<std::size_t, 2> lengths = {{1, 1}};
Paul's avatar
Paul committed
185
    std::string name() const { return "pooling"; }
Scott Thornton's avatar
Scott Thornton committed
186

Paul's avatar
Paul committed
187
188
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
189
        check_shapes{inputs, *this}.has(1).only_dims(4);
Paul's avatar
Paul committed
190

Paul's avatar
Paul committed
191
        const shape& input = inputs.at(0);
Paul's avatar
Paul committed
192
        auto t             = input.type();
Paul's avatar
Paul committed
193

Paul's avatar
Paul committed
194
195
        assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
        assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
Paul's avatar
Paul committed
196

Scott Thornton's avatar
Scott Thornton committed
197
198
199
200
201
202
        return {t,
                {
                    input.lens()[0],
                    input.lens()[1],
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
203
                        std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
Paul's avatar
Paul committed
204
                                                  static_cast<float>(stride[0]))) +
Scott Thornton's avatar
Scott Thornton committed
205
206
207
                            1)),
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
208
                        std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
Paul's avatar
Paul committed
209
                                                  static_cast<float>(stride[1]))) +
Scott Thornton's avatar
Scott Thornton committed
210
211
                            1)),
                }};
Paul's avatar
Paul committed
212
    }
Paul's avatar
Paul committed
213

Paul's avatar
Paul committed
214
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
215
216
217
    {
        MIGRAPH_THROW("not computable");
    }
Paul's avatar
Paul committed
218

Paul's avatar
Paul committed
219
    friend std::ostream& operator<<(std::ostream& os, const pooling& op)
Paul's avatar
Paul committed
220
    {
Paul's avatar
Paul committed
221
222
223
224
225
        os << op.name() << "[";
        os << "padding={" << stream_range(op.padding) << "}, ";
        os << "stride={" << stream_range(op.stride) << "}, ";
        os << "lengths={" << stream_range(op.lengths) << "}";
        os << "]";
Paul's avatar
Paul committed
226
227
        return os;
    }
Paul's avatar
Paul committed
228
229
};

Paul's avatar
Paul committed
230
struct activation
Paul's avatar
Paul committed
231
232
{
    std::string mode;
Paul's avatar
Paul committed
233
    std::string name() const { return "activation"; }
Paul's avatar
Paul committed
234
235
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
236
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
237
238
        return inputs.front();
    }
Paul's avatar
Paul committed
239

Paul's avatar
Paul committed
240
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
241
242
243
    {
        MIGRAPH_THROW("not computable");
    }
Paul's avatar
Paul committed
244
    friend std::ostream& operator<<(std::ostream& os, const activation& op)
Paul's avatar
Paul committed
245
    {
Paul's avatar
Paul committed
246
        os << op.name() << ":" << op.mode;
Paul's avatar
Paul committed
247
248
        return os;
    }
Paul's avatar
Paul committed
249
250
};

251
252
253
254
255
256
struct transpose
{
    std::vector<int64_t> dims;
    std::string name() const { return "transpose"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
257
        check_shapes{inputs, *this}.has(1);
258
        auto input         = inputs.at(0);
259
        auto input_lens    = input.lens();
260
261
        auto input_strides = input.strides();
        auto t             = input.type();
Paul's avatar
Paul committed
262
263
        if(dims.size() != input_lens.size())
        {
Paul's avatar
Paul committed
264
            MIGRAPH_THROW("Permutation has wrong number of axes");
265
266
267
        }
        std::vector<int64_t> axes(dims.size());
        std::iota(axes.begin(), axes.end(), 0);
Paul's avatar
Paul committed
268
269
        if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
        {
Paul's avatar
Paul committed
270
            MIGRAPH_THROW("Invalid permutation");
271
        }
272
273
        std::vector<size_t> output_lens(input_lens.size());
        std::vector<size_t> output_strides(input_lens.size());
Paul's avatar
Paul committed
274
275
276
        for(int i = 0; i < output_lens.size(); i++)
        {
            output_lens[i]    = input_lens[dims[i]];
277
278
            output_strides[i] = input_strides[dims[i]];
        }
279
        return {t, output_lens, output_strides};
280
    }
Paul's avatar
Paul committed
281
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
282
    {
Paul's avatar
Paul committed
283
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
284
    }
Paul's avatar
Paul committed
285
286
287
288
289
290
291
    friend std::ostream& operator<<(std::ostream& os, const transpose& op)
    {
        os << op.name() << "[";
        os << "dims={" << stream_range(op.dims) << "}";
        os << "]";
        return os;
    }
292
293
};

Paul's avatar
Paul committed
294
struct contiguous
295
296
297
298
{
    std::string name() const { return "contiguous"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
299
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
300
301
302
303
        auto lens = inputs.at(0).lens();
        auto t    = inputs.at(0).type();
        if(lens.size() < 2)
        {
Paul's avatar
Paul committed
304
            MIGRAPH_THROW("Number of dimensions should exceed 1");
305
306
307
        }
        return {t, lens};
    }
Paul's avatar
Paul committed
308
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
309
310
311
    {
        MIGRAPH_THROW("not computable");
    }
312
313
};

314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
/// Slice operator producing a strided view of its single input.
///
/// For each listed axis (every axis, in order, when `axes` is empty) the
/// half-open range [starts[i], ends[i]) is kept. Negative indices count from
/// the end of the axis, and indices are clamped to [0, len]. An empty range
/// gives a length of 0. The output keeps the input's strides.
struct slice
{
    std::vector<int64_t> axes;
    std::vector<int64_t> starts;
    std::vector<int64_t> ends;
    std::string name() const { return "slice"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        const auto& input_shape = inputs.front();
        auto t                  = input_shape.type();
        const auto& old_lens    = input_shape.lens();
        const auto& old_strides = input_shape.strides();

        std::vector<int64_t> t_axes = axes;
        if(t_axes.empty())
        {
            t_axes.resize(old_lens.size());
            std::iota(t_axes.begin(), t_axes.end(), 0);
        }
        if(starts.size() != t_axes.size() || ends.size() != t_axes.size())
        {
            MIGRAPH_THROW("inconsistent sizes");
        }

        // Wrap a negative index once, then clamp into [0, len]. `len` (not
        // len - 1) is the upper bound because ends are exclusive.
        auto fix_index = [](int64_t len, int64_t index) {
            if(index < 0)
                index += len;
            return std::min(std::max<int64_t>(index, 0), len);
        };

        std::vector<std::size_t> new_lens(old_lens.begin(), old_lens.end());
        for(std::size_t i = 0; i < t_axes.size(); i++)
        {
            auto axis = t_axes[i];
            if(axis < 0 || axis >= static_cast<int64_t>(old_lens.size()))
            {
                MIGRAPH_THROW("slice axis out of range");
            }
            auto len       = static_cast<int64_t>(old_lens[axis]);
            auto first     = fix_index(len, starts[i]);
            auto last      = fix_index(len, ends[i]);
            new_lens[axis] = last > first ? static_cast<std::size_t>(last - first) : 0;
        }
        return shape{t, new_lens, old_strides};
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        // NOTE(review): this aliases the input buffer from its first element and
        // does not advance the pointer by the slice starts, so a non-zero start
        // appears to select the wrong elements. Fixing it needs the input shape
        // and element size here — TODO confirm the intended lowering.
        return {std::move(output_shape), std::move(args.front().data)};
    }
};

/// Removes length-1 axes from its single input, aliasing the input data.
///
/// With `axes` empty every length-1 axis is removed; otherwise exactly the
/// listed axes are removed, and each must be in range and have length 1.
/// The remaining axes keep their original strides, so non-packed inputs are
/// still addressed correctly by the aliased buffer.
struct squeeze
{
    std::vector<int64_t> axes;
    std::string name() const { return "squeeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        const auto& input_shape = inputs[0];
        auto type               = input_shape.type();
        const auto& old_lens    = input_shape.lens();
        const auto& old_strides = input_shape.strides();
        auto rank               = static_cast<int64_t>(old_lens.size());

        for(auto axis : axes)
        {
            if(axis < 0 || axis >= rank)
            {
                MIGRAPH_THROW("squeeze axis out of range");
            }
            if(old_lens[axis] != 1)
            {
                MIGRAPH_THROW("squeeze axis dimension should be equal to 1");
            }
        }

        auto is_squeezed = [&](std::size_t i) {
            if(axes.empty())
                return old_lens[i] == 1;
            return std::find(axes.begin(), axes.end(), static_cast<int64_t>(i)) != axes.end();
        };

        std::vector<std::size_t> new_lens;
        std::vector<std::size_t> new_strides;
        for(std::size_t i = 0; i < old_lens.size(); i++)
        {
            if(is_squeezed(i))
                continue;
            new_lens.push_back(old_lens[i]);
            new_strides.push_back(old_strides[i]);
        }
        return shape{type, new_lens, new_strides};
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
};

/// Inserts length-1 axes into its single input, aliasing the input data.
///
/// `axes` are positions in the OUTPUT shape (rank = input rank + axes.size());
/// each must be in range and listed at most once. Original axes keep their
/// strides. An inserted axis gets stride = next stride * next length (or 1 when
/// it is last), which reproduces packed strides for packed inputs.
struct unsqueeze
{
    std::vector<int64_t> axes;
    std::string name() const { return "unsqueeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        const auto& input_shape = inputs[0];
        auto type               = input_shape.type();
        const auto& old_lens    = input_shape.lens();
        const auto& old_strides = input_shape.strides();
        std::size_t new_size    = old_lens.size() + axes.size();

        // Validating here guarantees exactly old_lens.size() non-inserted
        // positions, so the copy loop below cannot read past old_lens.
        std::vector<char> inserted(new_size, 0);
        for(auto axis : axes)
        {
            if(axis < 0 || axis >= static_cast<int64_t>(new_size))
            {
                MIGRAPH_THROW("unsqueeze axis out of range");
            }
            if(inserted[axis] != 0)
            {
                MIGRAPH_THROW("unsqueeze axes must be unique");
            }
            inserted[axis] = 1;
        }

        std::vector<std::size_t> new_lens(new_size);
        std::vector<std::size_t> new_strides(new_size);
        std::size_t p = old_lens.size();
        // Walk from the innermost axis so an inserted axis can derive its
        // stride from the axis after it.
        for(std::size_t i = new_size; i-- > 0;)
        {
            if(inserted[i] != 0)
            {
                new_lens[i]    = 1;
                new_strides[i] = (i + 1 < new_size) ? new_strides[i + 1] * new_lens[i + 1] : 1;
            }
            else
            {
                --p;
                new_lens[i]    = old_lens[p];
                new_strides[i] = old_strides[p];
            }
        }
        return shape{type, new_lens, new_strides};
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
};

Paul's avatar
Paul committed
419
420
421
struct reshape
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
422
    std::string name() const { return "reshape"; }
Paul's avatar
Paul committed
423
424
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
425
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
426
427
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(dims.begin(), dims.end());
428
429
430
        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
        if(n_neg_dims > 1)
            MIGRAPH_THROW("Dimensions for reshape can only have one -1 dim");
Paul's avatar
Paul committed
431
        for(std::size_t i = 0; i < dims.size(); i++)
Paul's avatar
Paul committed
432
433
434
435
        {
            if(dims[i] == 0)
                rdims[i] = idims[i];
        }
436
437
438
439
440
441
442
443
444
445
446
        if(n_neg_dims > 0)
        {
            size_t missing_dim =
                -inputs.front().elements() /
                std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
            for(std::size_t i = 0; i < rdims.size(); i++)
            {
                if(dims[i] == -1)
                    rdims[i] = missing_dim;
            }
        }
Paul's avatar
Paul committed
447
448
449
        if(dims.back() == -1)
        {
            rdims.pop_back();
Paul's avatar
Paul committed
450
            std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
Paul's avatar
Paul committed
451
        }
Scott Thornton's avatar
Scott Thornton committed
452
        shape s{inputs.front().type(), rdims};
Paul's avatar
Paul committed
453
        if(s.elements() != inputs.front().elements())
Paul's avatar
Paul committed
454
            MIGRAPH_THROW("Wrong number of elements for reshape");
Scott Thornton's avatar
Scott Thornton committed
455
        return s;
Paul's avatar
Paul committed
456
457
    }

Paul's avatar
Paul committed
458
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
459
    {
Paul's avatar
Paul committed
460
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
461
    }
Paul's avatar
Paul committed
462

Paul's avatar
Paul committed
463
    friend std::ostream& operator<<(std::ostream& os, const reshape& op)
Paul's avatar
Paul committed
464
    {
Paul's avatar
Paul committed
465
        os << op.name() << "[";
Paul's avatar
Paul committed
466
        os << "dims={" << stream_range(op.dims) << "}";
Paul's avatar
Paul committed
467
        os << "]";
Paul's avatar
Paul committed
468
469
        return os;
    }
Paul's avatar
Paul committed
470
471
};

472
473
struct gemm
{
Paul's avatar
Paul committed
474
    float alpha = 1.0;
Paul's avatar
Paul committed
475
    float beta  = 0.0;
476
    std::string name() const { return "gemm"; }
477
478
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
479
        check_shapes{inputs, *this}.has(2).same_type();
480
481
        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);
Scott Thornton's avatar
Scott Thornton committed
482
        auto t         = a.type();
483

484
        if(a.lens()[1] != b.lens()[0])
Paul's avatar
Paul committed
485
486
            MIGRAPH_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + "} x {" +
                          to_string_range(b.lens()) + "}");
Scott Thornton's avatar
Scott Thornton committed
487
        return {t, {a.lens()[0], b.lens()[1]}};
488
    }
489

Paul's avatar
Paul committed
490
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
491
492
493
    {
        MIGRAPH_THROW("not computable");
    }
494
495

    friend std::ostream& operator<<(std::ostream& os, const gemm& op)
496
497
    {
        os << op.name() << "[";
498
        os << "]";
Scott Thornton's avatar
Scott Thornton committed
499
        return os;
500
501
502
    }
};

503
struct unary
Scott Thornton's avatar
Scott Thornton committed
504
{
505
506
    shape compute_shape(std::vector<shape> inputs) const
    {
507
508
        check_shapes{inputs}.has(1);
        return inputs.at(0);
509
    }
Paul's avatar
Paul committed
510
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
511
512
513
    {
        MIGRAPH_THROW("not computable");
    }
Scott Thornton's avatar
Scott Thornton committed
514
515
};

516
517
/// Operator named "identity"; shape inference and compute() come from unary.
struct identity : public unary
{
    std::string name() const
    {
        return "identity";
    }
};

struct abs : unary
Scott Thornton's avatar
Scott Thornton committed
522
{
523
    std::string name() const { return "abs"; }
Scott Thornton's avatar
Scott Thornton committed
524
525
};

526
struct exp : unary
Scott Thornton's avatar
Scott Thornton committed
527
{
528
    std::string name() const { return "exp"; }
Scott Thornton's avatar
Scott Thornton committed
529
530
};

531
struct sin : unary
Scott Thornton's avatar
Scott Thornton committed
532
{
533
    std::string name() const { return "sin"; }
Scott Thornton's avatar
Scott Thornton committed
534
535
};

536
struct cos : unary
Scott Thornton's avatar
Scott Thornton committed
537
{
538
    std::string name() const { return "cos"; }
Scott Thornton's avatar
Scott Thornton committed
539
540
};

541
struct tan : unary
Scott Thornton's avatar
Scott Thornton committed
542
{
543
    std::string name() const { return "tan"; }
Scott Thornton's avatar
Scott Thornton committed
544
545
};

546
struct asin : unary
Scott Thornton's avatar
Scott Thornton committed
547
{
548
    std::string name() const { return "asin"; }
Scott Thornton's avatar
Scott Thornton committed
549
550
};

551
struct acos : unary
Scott Thornton's avatar
Scott Thornton committed
552
{
553
    std::string name() const { return "acos"; }
Scott Thornton's avatar
Scott Thornton committed
554
555
};

556
struct atan : unary
Scott Thornton's avatar
Scott Thornton committed
557
{
558
    std::string name() const { return "atan"; }
Scott Thornton's avatar
Scott Thornton committed
559
560
};

561
struct softmax : unary
Scott Thornton's avatar
Scott Thornton committed
562
{
563
    std::string name() const { return "softmax"; }
Scott Thornton's avatar
Scott Thornton committed
564
565
};

566
struct tanh : unary
Scott Thornton's avatar
Scott Thornton committed
567
{
568
    std::string name() const { return "tanh"; }
Scott Thornton's avatar
Scott Thornton committed
569
570
};

571
struct sigmoid : unary
Scott Thornton's avatar
Scott Thornton committed
572
{
573
    std::string name() const { return "sigmoid"; }
Scott Thornton's avatar
Scott Thornton committed
574
575
};

576
struct neg : unary
Scott Thornton's avatar
Scott Thornton committed
577
{
578
    std::string name() const { return "neg"; }
Scott Thornton's avatar
Scott Thornton committed
579
580
};

581
struct flatten
Scott Thornton's avatar
Scott Thornton committed
582
{
Paul's avatar
Paul committed
583
    uint64_t axis = 0;
Scott Thornton's avatar
Scott Thornton committed
584
    std::string name() const { return "flatten"; }
Paul's avatar
Paul committed
585
586
587
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
Paul's avatar
Paul committed
588
589
        auto&& lens = inputs.front().lens();

Paul's avatar
Paul committed
590
        if(axis > lens.size())
Paul's avatar
Paul committed
591
        {
Paul's avatar
Paul committed
592
            MIGRAPH_THROW("axis for flatten must be less than tensor rank");
Paul's avatar
Paul committed
593
        }
Paul's avatar
Paul committed
594
595
596
597
        auto x =
            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
        auto y =
            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
598
        return {inputs.at(0).type(), {x, y}};
Paul's avatar
Paul committed
599
600
601
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
Paul's avatar
Paul committed
602
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
603
    }
Paul's avatar
Paul committed
604
605
606
607
608
609
610
    friend std::ostream& operator<<(std::ostream& os, const flatten& op)
    {
        os << op.name() << "[";
        os << "axis=" << op.axis;
        os << "]";
        return os;
    }
Scott Thornton's avatar
Scott Thornton committed
611
};
612
613
614
615
616
617
struct broadcast
{
    uint64_t axis = 0;
    std::string name() const { return "broadcast"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
618
619
620
621
        auto t      = inputs.at(0).type();
        auto result = inputs.at(0);
        auto input  = inputs.at(1);

Paul's avatar
Paul committed
622
        std::vector<size_t> bcast_strides(result.lens().size(), 0);
623

Paul's avatar
Paul committed
624
625
        if(std::all_of(
               result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; }))
626
        {
Scott Thornton's avatar
Scott Thornton committed
627
            if(axis != 0)
Paul's avatar
Paul committed
628
                MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
Paul's avatar
Paul committed
629
            return {t, result.lens(), std::move(bcast_strides)};
630
631
632
        }
        else
        {
Paul's avatar
Paul committed
633
634
            assert(result.lens().size() - axis >= input.lens().size());
            if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin() + axis))
Paul's avatar
Paul committed
635
                MIGRAPH_THROW("when broadcasting success sizes must match");
Paul's avatar
Paul committed
636
            std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
Paul's avatar
Paul committed
637
            return {t, result.lens(), std::move(bcast_strides)};
638
639
        }
    }
Paul's avatar
Paul committed
640
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
641
    {
Paul's avatar
Paul committed
642
        return {std::move(output_shape), std::move(args.at(1).data)};
Scott Thornton's avatar
Scott Thornton committed
643
    }
Paul's avatar
Paul committed
644
645
646
647
648
649
650
    friend std::ostream& operator<<(std::ostream& os, const broadcast& op)
    {
        os << op.name() << "[";
        os << "axis=" << op.axis;
        os << "]";
        return os;
    }
651
652
};

653
struct binary
Scott Thornton's avatar
Scott Thornton committed
654
{
655
    uint64_t broadcast = 0;
656
657
    shape compute_shape(std::vector<shape> inputs) const
    {
658
659
        check_shapes{inputs}.has(2).same_type().same_dims();
        return inputs.at(0);
660
    }
Paul's avatar
Paul committed
661
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
662
663
664
    {
        MIGRAPH_THROW("not computable");
    }
Scott Thornton's avatar
Scott Thornton committed
665
666
};

667
668
669
670
671
672
/// Operator named "add"; shape inference and compute() come from binary.
struct add : public binary
{
    std::string name() const
    {
        return "add";
    }
};

struct sub : binary
Scott Thornton's avatar
Scott Thornton committed
673
674
675
676
{
    std::string name() const { return "sub"; }
};

677
struct mul : binary
Scott Thornton's avatar
Scott Thornton committed
678
679
680
681
{
    std::string name() const { return "mul"; }
};

682
struct div : binary
Scott Thornton's avatar
Scott Thornton committed
683
684
685
686
{
    std::string name() const { return "div"; }
};

Paul's avatar
Paul committed
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
/// Reinterprets part of an existing buffer as shape `s`.
///
/// compute() returns an argument of shape `s` whose data pointer is the first
/// input's data() advanced by `offset`. No bounds check is done against the
/// source buffer's size.
/// NOTE(review): `offset` is presumably in bytes (pointer arithmetic on what
/// data() returns) — confirm data() yields char*.
struct load
{
    shape s;
    std::size_t offset = 0;
    std::string name() const { return "load"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        // Pass *this like the other named operators (e.g. outline); presumably
        // lets check_shapes report the operator name in its errors.
        check_shapes{inputs, *this}.has(1);
        return s;
    }
    argument compute(context&, const shape&, const std::vector<argument>& args) const
    {
        return {s, args[0].data() + offset};
    }
};

Paul's avatar
Paul committed
703
struct outline
Scott Thornton's avatar
Scott Thornton committed
704
{
Paul's avatar
Paul committed
705
706
    shape s;
    std::string name() const { return "outline"; }
Paul's avatar
Paul committed
707
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
708
    {
Paul's avatar
Paul committed
709
        check_shapes{inputs, *this}.has(0);
Paul's avatar
Paul committed
710
711
        return s;
    }
Paul's avatar
Paul committed
712
713
714
715
    argument compute(context&, const shape&, const std::vector<argument>&) const
    {
        return {s, nullptr};
    }
Scott Thornton's avatar
Scott Thornton committed
716
717
};

Paul's avatar
Paul committed
718
} // namespace migraph
Paul's avatar
Paul committed
719
720

#endif