operators.hpp 17 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPH_GUARD_OPERATORS_HPP
#define MIGRAPH_GUARD_OPERATORS_HPP
Paul's avatar
Paul committed
3

4
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include <migraph/operation.hpp>
#include <migraph/check_shapes.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/streamutils.hpp>
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
namespace migraph {
Paul's avatar
Paul committed
13

Paul's avatar
Paul committed
14
15
struct not_computable
{
Paul's avatar
Paul committed
16
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
17
18
19
    {
        MIGRAPH_THROW("not computable");
    }
Paul's avatar
Paul committed
20
21
};

22
23
struct batch_norm_inference
{
24
25
    float epsilon  = 1.0e-6f;
    float momentum = 0.9f;
26
27
28

    std::string name() const { return "batch_norm_inference"; }

29
30
31
32
33
34
35
36
    enum bn_infer_mode_t
    {
        per_activation,
        spatial,
    };

    bn_infer_mode_t bn_mode = spatial;

37
38
    bool is_test = false;

39
40
41
42
43
44
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        return inputs.front();
    }

Paul's avatar
Paul committed
45
    argument compute(context&, const shape&, const std::vector<argument>&) const
46
47
48
49
50
    {
        MIGRAPH_THROW("not computable");
    }
};

Paul's avatar
Paul committed
51
struct convolution
Paul's avatar
Paul committed
52
{
Paul's avatar
Paul committed
53
54
55
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Paul's avatar
Paul committed
56
57
58
59
60
61
62
    enum padding_mode_t
    {
        default_, // NOLINT
        same,
        valid
    };
    padding_mode_t padding_mode = default_;
Paul's avatar
Paul committed
63
    std::string name() const { return "convolution"; }
Paul's avatar
Paul committed
64
65
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
66
        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
Paul's avatar
Paul committed
67

Paul's avatar
Paul committed
68
        const shape& input   = inputs.at(0);
Paul's avatar
Paul committed
69
        const shape& weights = inputs.at(1);
Paul's avatar
Paul committed
70
        auto t               = input.type();
Paul's avatar
Paul committed
71
72
        if(padding_mode == default_)
        {
Paul's avatar
Paul committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
            return {t,
                    {
                        input.lens()[0],
                        weights.lens()[0],
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
                             2 * padding[0]) /
                                    stride[0] +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
                             2 * padding[1]) /
                                    stride[1] +
                                1)),
                    }};
Paul's avatar
Paul committed
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
                     weights.lens()[0],
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {
                t,
                {input.lens()[0],
                 weights.lens()[0],
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
        }
        else
        {
Paul's avatar
Paul committed
114
            MIGRAPH_THROW("Invalid padding mode");
Paul's avatar
Paul committed
115
        }
Paul's avatar
Paul committed
116
    }
Paul's avatar
Paul committed
117

Paul's avatar
Paul committed
118
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
119
120
121
    {
        MIGRAPH_THROW("not computable");
    }
Paul's avatar
Paul committed
122

Paul's avatar
Paul committed
123
    friend std::ostream& operator<<(std::ostream& os, const convolution& op)
Paul's avatar
Paul committed
124
    {
Paul's avatar
Paul committed
125
126
127
128
129
        os << op.name() << "[";
        os << "padding={" << stream_range(op.padding) << "}, ";
        os << "stride={" << stream_range(op.stride) << "}, ";
        os << "dilation={" << stream_range(op.dilation) << "}";
        os << "]";
Paul's avatar
Paul committed
130
131
        return os;
    }
Paul's avatar
Paul committed
132
133
};

Scott Thornton's avatar
Scott Thornton committed
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
/// im2col shape descriptor.
///
/// compute_shape() expects two 4-D inputs: inputs[0] is the data tensor and
/// inputs[1] the weights (presumably NCHW / OIHW -- TODO confirm). Only a batch
/// size of 1 is supported. The result is a 2-D shape
/// {out_h * out_w, kernel_h * kernel_w * weights.lens()[1]}.
/// compute() always throws.
struct im2col
{
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
    // NOTE(review): declared but not read by compute_shape(); padding is always explicit.
    enum padding_mode_t
    {
        default_, // NOLINT
        same,
        valid
    };

    std::string name() const { return "im2col"; }

    /// Output extent of one spatial axis:
    ///   floor((in + 2*pad - (dil*(k-1) + 1)) / s) + 1, clamped to at least 1.
    /// Signed arithmetic keeps an oversized kernel from wrapping around in
    /// std::size_t. Requires s != 0.
    static std::size_t
    output_dim(std::size_t in, std::size_t k, std::size_t pad, std::size_t s, std::size_t dil)
    {
        const auto extent = static_cast<std::ptrdiff_t>(dil * (k - 1) + 1);
        const auto span   = static_cast<std::ptrdiff_t>(in + 2 * pad) - extent;
        return static_cast<std::size_t>(
            std::max<std::ptrdiff_t>(1, span / static_cast<std::ptrdiff_t>(s) + 1));
    }

    shape compute_shape(std::vector<shape> inputs) const
    {
        // Validate arity and rank before indexing into inputs / lens().
        check_shapes{inputs, *this}.has(2).same_ndims().only_dims(4);
        if(std::any_of(stride.begin(), stride.end(), [](std::size_t s) { return s == 0; }))
            MIGRAPH_THROW("im2col stride must be non-zero");

        const shape& input   = inputs.at(0);
        const shape& weights = inputs.at(1);
        auto batch_size      = input.lens()[0];
        auto input_channels  = weights.lens()[1];
        auto kernel_height   = weights.lens()[2];
        auto kernel_width    = weights.lens()[3];
        if(batch_size != 1)
            MIGRAPH_THROW("im2col only support batch_size 1");

        auto output_height =
            output_dim(input.lens()[2], kernel_height, padding[0], stride[0], dilation[0]);
        auto output_width =
            output_dim(input.lens()[3], kernel_width, padding[1], stride[1], dilation[1]);
        auto channels_col = kernel_height * kernel_width * input_channels;
        return {input.type(), {output_height * output_width, channels_col}};
    }

    /// Always throws "not computable".
    argument compute(context&, const shape&, const std::vector<argument>&) const
    {
        MIGRAPH_THROW("not computable");
    }
};


Paul's avatar
Paul committed
180
struct pooling
Paul's avatar
Paul committed
181
{
Paul's avatar
Paul committed
182
    std::string mode                   = "average";
Paul's avatar
Paul committed
183
184
185
    std::array<std::size_t, 2> padding = {{0, 0}};
    std::array<std::size_t, 2> stride  = {{1, 1}};
    std::array<std::size_t, 2> lengths = {{1, 1}};
Paul's avatar
Paul committed
186
    std::string name() const { return "pooling"; }
Scott Thornton's avatar
Scott Thornton committed
187

Paul's avatar
Paul committed
188
189
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
190
        check_shapes{inputs, *this}.has(1).only_dims(4);
Paul's avatar
Paul committed
191

Paul's avatar
Paul committed
192
        const shape& input = inputs.at(0);
Paul's avatar
Paul committed
193
        auto t             = input.type();
Paul's avatar
Paul committed
194

Paul's avatar
Paul committed
195
196
        assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
        assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
Paul's avatar
Paul committed
197

Scott Thornton's avatar
Scott Thornton committed
198
199
200
201
202
203
        return {t,
                {
                    input.lens()[0],
                    input.lens()[1],
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
204
                        std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
Paul's avatar
Paul committed
205
                                                  static_cast<float>(stride[0]))) +
Scott Thornton's avatar
Scott Thornton committed
206
207
208
                            1)),
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
209
                        std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
Paul's avatar
Paul committed
210
                                                  static_cast<float>(stride[1]))) +
Scott Thornton's avatar
Scott Thornton committed
211
212
                            1)),
                }};
Paul's avatar
Paul committed
213
    }
Paul's avatar
Paul committed
214

Paul's avatar
Paul committed
215
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
216
217
218
    {
        MIGRAPH_THROW("not computable");
    }
Paul's avatar
Paul committed
219

Paul's avatar
Paul committed
220
    friend std::ostream& operator<<(std::ostream& os, const pooling& op)
Paul's avatar
Paul committed
221
    {
Paul's avatar
Paul committed
222
223
224
225
226
        os << op.name() << "[";
        os << "padding={" << stream_range(op.padding) << "}, ";
        os << "stride={" << stream_range(op.stride) << "}, ";
        os << "lengths={" << stream_range(op.lengths) << "}";
        os << "]";
Paul's avatar
Paul committed
227
228
        return os;
    }
Paul's avatar
Paul committed
229
230
};

Paul's avatar
Paul committed
231
struct activation
Paul's avatar
Paul committed
232
233
{
    std::string mode;
Paul's avatar
Paul committed
234
    std::string name() const { return "activation"; }
Paul's avatar
Paul committed
235
236
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
237
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
238
239
        return inputs.front();
    }
Paul's avatar
Paul committed
240

Paul's avatar
Paul committed
241
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
242
243
244
    {
        MIGRAPH_THROW("not computable");
    }
Paul's avatar
Paul committed
245
    friend std::ostream& operator<<(std::ostream& os, const activation& op)
Paul's avatar
Paul committed
246
    {
Paul's avatar
Paul committed
247
        os << op.name() << ":" << op.mode;
Paul's avatar
Paul committed
248
249
        return os;
    }
Paul's avatar
Paul committed
250
251
};

252
253
254
255
256
257
struct transpose
{
    std::vector<int64_t> dims;
    std::string name() const { return "transpose"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
258
        check_shapes{inputs, *this}.has(1);
259
        auto input         = inputs.at(0);
260
        auto input_lens    = input.lens();
261
262
        auto input_strides = input.strides();
        auto t             = input.type();
Paul's avatar
Paul committed
263
264
        if(dims.size() != input_lens.size())
        {
Paul's avatar
Paul committed
265
            MIGRAPH_THROW("Permutation has wrong number of axes");
266
267
268
        }
        std::vector<int64_t> axes(dims.size());
        std::iota(axes.begin(), axes.end(), 0);
Paul's avatar
Paul committed
269
270
        if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
        {
Paul's avatar
Paul committed
271
            MIGRAPH_THROW("Invalid permutation");
272
        }
273
274
        std::vector<size_t> output_lens(input_lens.size());
        std::vector<size_t> output_strides(input_lens.size());
Paul's avatar
Paul committed
275
276
277
        for(int i = 0; i < output_lens.size(); i++)
        {
            output_lens[i]    = input_lens[dims[i]];
278
279
            output_strides[i] = input_strides[dims[i]];
        }
280
        return {t, output_lens, output_strides};
281
    }
Paul's avatar
Paul committed
282
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
283
    {
Paul's avatar
Paul committed
284
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
285
    }
Paul's avatar
Paul committed
286
287
288
289
290
291
292
    friend std::ostream& operator<<(std::ostream& os, const transpose& op)
    {
        os << op.name() << "[";
        os << "dims={" << stream_range(op.dims) << "}";
        os << "]";
        return os;
    }
293
294
};

Paul's avatar
Paul committed
295
struct contiguous
296
297
298
299
{
    std::string name() const { return "contiguous"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
300
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
301
302
303
304
        auto lens = inputs.at(0).lens();
        auto t    = inputs.at(0).type();
        if(lens.size() < 2)
        {
Paul's avatar
Paul committed
305
            MIGRAPH_THROW("Number of dimensions should exceed 1");
306
307
308
        }
        return {t, lens};
    }
Paul's avatar
Paul committed
309
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
310
311
312
    {
        MIGRAPH_THROW("not computable");
    }
313
314
};

Paul's avatar
Paul committed
315
316
317
struct reshape
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
318
    std::string name() const { return "reshape"; }
Paul's avatar
Paul committed
319
320
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
321
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
322
323
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(dims.begin(), dims.end());
324
325
326
        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
        if(n_neg_dims > 1)
            MIGRAPH_THROW("Dimensions for reshape can only have one -1 dim");
Paul's avatar
Paul committed
327
        for(std::size_t i = 0; i < dims.size(); i++)
Paul's avatar
Paul committed
328
329
330
331
        {
            if(dims[i] == 0)
                rdims[i] = idims[i];
        }
332
333
334
335
336
337
338
339
340
341
342
        if(n_neg_dims > 0)
        {
            size_t missing_dim =
                -inputs.front().elements() /
                std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
            for(std::size_t i = 0; i < rdims.size(); i++)
            {
                if(dims[i] == -1)
                    rdims[i] = missing_dim;
            }
        }
Paul's avatar
Paul committed
343
344
345
        if(dims.back() == -1)
        {
            rdims.pop_back();
Paul's avatar
Paul committed
346
            std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
Paul's avatar
Paul committed
347
        }
Scott Thornton's avatar
Scott Thornton committed
348
        shape s{inputs.front().type(), rdims};
Paul's avatar
Paul committed
349
        if(s.elements() != inputs.front().elements())
Paul's avatar
Paul committed
350
            MIGRAPH_THROW("Wrong number of elements for reshape");
Scott Thornton's avatar
Scott Thornton committed
351
        return s;
Paul's avatar
Paul committed
352
353
    }

Paul's avatar
Paul committed
354
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
355
    {
Paul's avatar
Paul committed
356
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
357
    }
Paul's avatar
Paul committed
358

Paul's avatar
Paul committed
359
    friend std::ostream& operator<<(std::ostream& os, const reshape& op)
Paul's avatar
Paul committed
360
    {
Paul's avatar
Paul committed
361
        os << op.name() << "[";
Paul's avatar
Paul committed
362
        os << "dims={" << stream_range(op.dims) << "}";
Paul's avatar
Paul committed
363
        os << "]";
Paul's avatar
Paul committed
364
365
        return os;
    }
Paul's avatar
Paul committed
366
367
};

368
369
struct gemm
{
Paul's avatar
Paul committed
370
    float alpha = 1.0;
Paul's avatar
Paul committed
371
    float beta  = 0.0;
372
    std::string name() const { return "gemm"; }
373
374
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
375
        check_shapes{inputs, *this}.has(2).same_type();
376
377
        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);
Scott Thornton's avatar
Scott Thornton committed
378
        auto t         = a.type();
379

380
        if(a.lens()[1] != b.lens()[0])
Paul's avatar
Paul committed
381
382
            MIGRAPH_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + "} x {" +
                          to_string_range(b.lens()) + "}");
Scott Thornton's avatar
Scott Thornton committed
383
        return {t, {a.lens()[0], b.lens()[1]}};
384
    }
385

Paul's avatar
Paul committed
386
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
387
388
389
    {
        MIGRAPH_THROW("not computable");
    }
390
391

    friend std::ostream& operator<<(std::ostream& os, const gemm& op)
392
393
    {
        os << op.name() << "[";
394
        os << "]";
Scott Thornton's avatar
Scott Thornton committed
395
        return os;
396
397
398
    }
};

399
struct unary
Scott Thornton's avatar
Scott Thornton committed
400
{
401
402
    shape compute_shape(std::vector<shape> inputs) const
    {
403
404
        check_shapes{inputs}.has(1);
        return inputs.at(0);
405
    }
Paul's avatar
Paul committed
406
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
407
408
409
    {
        MIGRAPH_THROW("not computable");
    }
Scott Thornton's avatar
Scott Thornton committed
410
411
};

412
413
/// "identity" operator; shape rule and compute() come from `unary`.
struct identity : unary
{
    std::string name() const { return "identity"; }
};

struct abs : unary
Scott Thornton's avatar
Scott Thornton committed
418
{
419
    std::string name() const { return "abs"; }
Scott Thornton's avatar
Scott Thornton committed
420
421
};

422
struct exp : unary
Scott Thornton's avatar
Scott Thornton committed
423
{
424
    std::string name() const { return "exp"; }
Scott Thornton's avatar
Scott Thornton committed
425
426
};

427
struct sin : unary
Scott Thornton's avatar
Scott Thornton committed
428
{
429
    std::string name() const { return "sin"; }
Scott Thornton's avatar
Scott Thornton committed
430
431
};

432
struct cos : unary
Scott Thornton's avatar
Scott Thornton committed
433
{
434
    std::string name() const { return "cos"; }
Scott Thornton's avatar
Scott Thornton committed
435
436
};

437
struct tan : unary
Scott Thornton's avatar
Scott Thornton committed
438
{
439
    std::string name() const { return "tan"; }
Scott Thornton's avatar
Scott Thornton committed
440
441
};

442
struct asin : unary
Scott Thornton's avatar
Scott Thornton committed
443
{
444
    std::string name() const { return "asin"; }
Scott Thornton's avatar
Scott Thornton committed
445
446
};

447
struct acos : unary
Scott Thornton's avatar
Scott Thornton committed
448
{
449
    std::string name() const { return "acos"; }
Scott Thornton's avatar
Scott Thornton committed
450
451
};

452
struct atan : unary
Scott Thornton's avatar
Scott Thornton committed
453
{
454
    std::string name() const { return "atan"; }
Scott Thornton's avatar
Scott Thornton committed
455
456
};

457
struct softmax : unary
Scott Thornton's avatar
Scott Thornton committed
458
{
459
    std::string name() const { return "softmax"; }
Scott Thornton's avatar
Scott Thornton committed
460
461
};

462
struct tanh : unary
Scott Thornton's avatar
Scott Thornton committed
463
{
464
    std::string name() const { return "tanh"; }
Scott Thornton's avatar
Scott Thornton committed
465
466
};

467
struct sigmoid : unary
Scott Thornton's avatar
Scott Thornton committed
468
{
469
    std::string name() const { return "sigmoid"; }
Scott Thornton's avatar
Scott Thornton committed
470
471
};

472
struct neg : unary
Scott Thornton's avatar
Scott Thornton committed
473
{
474
    std::string name() const { return "neg"; }
Scott Thornton's avatar
Scott Thornton committed
475
476
};

477
struct flatten
Scott Thornton's avatar
Scott Thornton committed
478
{
Paul's avatar
Paul committed
479
    uint64_t axis = 0;
Scott Thornton's avatar
Scott Thornton committed
480
    std::string name() const { return "flatten"; }
Paul's avatar
Paul committed
481
482
483
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
Paul's avatar
Paul committed
484
485
        auto&& lens = inputs.front().lens();

Paul's avatar
Paul committed
486
        if(axis > lens.size())
Paul's avatar
Paul committed
487
        {
Paul's avatar
Paul committed
488
            MIGRAPH_THROW("axis for flatten must be less than tensor rank");
Paul's avatar
Paul committed
489
        }
Paul's avatar
Paul committed
490
491
492
493
        auto x =
            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
        auto y =
            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
494
        return {inputs.at(0).type(), {x, y}};
Paul's avatar
Paul committed
495
496
497
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
Paul's avatar
Paul committed
498
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
499
    }
Paul's avatar
Paul committed
500
501
502
503
504
505
506
    friend std::ostream& operator<<(std::ostream& os, const flatten& op)
    {
        os << op.name() << "[";
        os << "axis=" << op.axis;
        os << "]";
        return os;
    }
Scott Thornton's avatar
Scott Thornton committed
507
};
508
509
510
511
512
513
struct broadcast
{
    uint64_t axis = 0;
    std::string name() const { return "broadcast"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
514
515
516
517
        auto t      = inputs.at(0).type();
        auto result = inputs.at(0);
        auto input  = inputs.at(1);

Paul's avatar
Paul committed
518
        std::vector<size_t> bcast_strides(result.lens().size(), 0);
519

Paul's avatar
Paul committed
520
521
        if(std::all_of(
               result.lens().cbegin(), result.lens().cend(), [&](auto x) { return x == 1; }))
522
        {
Scott Thornton's avatar
Scott Thornton committed
523
            if(axis != 0)
Paul's avatar
Paul committed
524
                MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
Paul's avatar
Paul committed
525
            return {t, result.lens(), std::move(bcast_strides)};
526
527
528
        }
        else
        {
Paul's avatar
Paul committed
529
530
            assert(result.lens().size() - axis >= input.lens().size());
            if(!std::equal(input.lens().begin(), input.lens().end(), result.lens().begin() + axis))
Paul's avatar
Paul committed
531
                MIGRAPH_THROW("when broadcasting success sizes must match");
Paul's avatar
Paul committed
532
            std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
Paul's avatar
Paul committed
533
            return {t, result.lens(), std::move(bcast_strides)};
534
535
        }
    }
Paul's avatar
Paul committed
536
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
537
    {
Paul's avatar
Paul committed
538
        return {std::move(output_shape), std::move(args.at(1).data)};
Scott Thornton's avatar
Scott Thornton committed
539
    }
Paul's avatar
Paul committed
540
541
542
543
544
545
546
    friend std::ostream& operator<<(std::ostream& os, const broadcast& op)
    {
        os << op.name() << "[";
        os << "axis=" << op.axis;
        os << "]";
        return os;
    }
547
548
};

549
struct binary
Scott Thornton's avatar
Scott Thornton committed
550
{
551
    uint64_t broadcast = 0;
552
553
    shape compute_shape(std::vector<shape> inputs) const
    {
554
555
        check_shapes{inputs}.has(2).same_type().same_dims();
        return inputs.at(0);
556
    }
Paul's avatar
Paul committed
557
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
558
559
560
    {
        MIGRAPH_THROW("not computable");
    }
Scott Thornton's avatar
Scott Thornton committed
561
562
};

563
564
565
566
567
568
/// "add" operator: a named `binary` op.
/// Shape checking and compute() are inherited unchanged from `binary`.
struct add : binary
{
    std::string name() const
    {
        return "add";
    }
};

struct sub : binary
Scott Thornton's avatar
Scott Thornton committed
569
570
571
572
{
    std::string name() const { return "sub"; }
};

573
struct mul : binary
Scott Thornton's avatar
Scott Thornton committed
574
575
576
577
{
    std::string name() const { return "mul"; }
};

578
struct div : binary
Scott Thornton's avatar
Scott Thornton committed
579
580
581
582
{
    std::string name() const { return "div"; }
};

Paul's avatar
Paul committed
583
struct outline
Scott Thornton's avatar
Scott Thornton committed
584
{
Paul's avatar
Paul committed
585
586
    shape s;
    std::string name() const { return "outline"; }
Paul's avatar
Paul committed
587
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
588
    {
Paul's avatar
Paul committed
589
        check_shapes{inputs, *this}.has(0);
Paul's avatar
Paul committed
590
591
        return s;
    }
Paul's avatar
Paul committed
592
593
594
595
    argument compute(context&, const shape&, const std::vector<argument>&) const
    {
        return {s, nullptr};
    }
Scott Thornton's avatar
Scott Thornton committed
596
597
};

Paul's avatar
Paul committed
598
} // namespace migraph
Paul's avatar
Paul committed
599
600

#endif