"docs/source/en/api/schedulers/unipc.md" did not exist on "856dad57bb7a9ee13af4a08492e524b0a145a2c5"
operators.hpp 22.1 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPH_GUARD_OPERATORS_HPP
#define MIGRAPH_GUARD_OPERATORS_HPP
Paul's avatar
Paul committed
3

4
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <numeric>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include <migraph/operation.hpp>
#include <migraph/check_shapes.hpp>
#include <migraph/stringutils.hpp>
#include <migraph/streamutils.hpp>
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
namespace migraph {
13
namespace op {
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
16
struct not_computable
{
Paul's avatar
Paul committed
17
    argument compute(context&, const shape&, const std::vector<argument>&) const
Paul's avatar
Paul committed
18
19
20
    {
        MIGRAPH_THROW("not computable");
    }
Paul's avatar
Paul committed
21
22
};

23
24
struct batch_norm_inference
{
25
26
    float epsilon  = 1.0e-6f;
    float momentum = 0.9f;
27
28
29

    std::string name() const { return "batch_norm_inference"; }

30
31
32
33
34
35
36
37
    enum bn_infer_mode_t
    {
        per_activation,
        spatial,
    };

    bn_infer_mode_t bn_mode = spatial;

38
39
    bool is_test = false;

40
41
42
43
44
45
46
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        return inputs.front();
    }
};

Paul's avatar
Paul committed
47
struct convolution
Paul's avatar
Paul committed
48
{
Paul's avatar
Paul committed
49
50
51
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
Paul's avatar
Paul committed
52
53
54
55
56
57
58
    enum padding_mode_t
    {
        default_, // NOLINT
        same,
        valid
    };
    padding_mode_t padding_mode = default_;
Paul's avatar
Paul committed
59
    std::string name() const { return "convolution"; }
Paul's avatar
Paul committed
60
61
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
62
        check_shapes{inputs, *this}.has(2).same_type().same_ndims().only_dims(4);
Paul's avatar
Paul committed
63

Paul's avatar
Paul committed
64
        const shape& input   = inputs.at(0);
Paul's avatar
Paul committed
65
        const shape& weights = inputs.at(1);
Paul's avatar
Paul committed
66
        auto t               = input.type();
Paul's avatar
Paul committed
67
68
        if(padding_mode == default_)
        {
Paul's avatar
Paul committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
            return {t,
                    {
                        input.lens()[0],
                        weights.lens()[0],
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[2] - (1 + dilation[0] * (weights.lens()[2] - 1)) +
                             2 * padding[0]) /
                                    stride[0] +
                                1)),
                        std::size_t(std::max<std::ptrdiff_t>(
                            1,
                            (input.lens()[3] - (1 + dilation[1] * (weights.lens()[3] - 1)) +
                             2 * padding[1]) /
                                    stride[1] +
                                1)),
                    }};
Paul's avatar
Paul committed
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
        }
        else if(padding_mode == same)
        {
            return {t,
                    {input.lens()[0],
                     weights.lens()[0],
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
                     static_cast<std::size_t>(
                         std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
        }
        else if(padding_mode == valid)
        {
            return {
                t,
                {input.lens()[0],
                 weights.lens()[0],
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
                 static_cast<std::size_t>(std::ceil(
                     static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) / stride[1]))}};
        }
        else
        {
Paul's avatar
Paul committed
110
            MIGRAPH_THROW("Invalid padding mode");
Paul's avatar
Paul committed
111
        }
Paul's avatar
Paul committed
112
    }
Paul's avatar
Paul committed
113

Paul's avatar
Paul committed
114
    friend std::ostream& operator<<(std::ostream& os, const convolution& op)
Paul's avatar
Paul committed
115
    {
Paul's avatar
Paul committed
116
117
118
119
120
        os << op.name() << "[";
        os << "padding={" << stream_range(op.padding) << "}, ";
        os << "stride={" << stream_range(op.stride) << "}, ";
        os << "dilation={" << stream_range(op.dilation) << "}";
        os << "]";
Paul's avatar
Paul committed
121
122
        return os;
    }
Paul's avatar
Paul committed
123
124
};

Scott Thornton's avatar
Scott Thornton committed
125
126
/// Shape inference for the im2col transform of a 2-D convolution.
struct im2col
{
    std::array<std::size_t, 2> padding  = {{0, 0}};
    std::array<std::size_t, 2> stride   = {{1, 1}};
    std::array<std::size_t, 2> dilation = {{1, 1}};
    enum padding_mode_t
    {
        default_, // NOLINT
        same,
        valid
    };

    std::string name() const { return "im2col"; }

    /// Infers the shape of the im2col matrix.
    ///
    /// inputs: the data tensor and the weights, both 4-D. Only a batch size of 1
    /// is supported. The result is the 2-D shape
    /// {output_height * output_width, kernel_height * kernel_width * input_channels}.
    /// Output lengths use explicit `padding`, `stride` and `dilation` and are
    /// clamped to at least 1.
    shape compute_shape(std::vector<shape> inputs) const
    {
        // Validate first. The original indexed inputs[0], inputs[1] and their
        // lengths before this check, which was out of bounds on bad input.
        check_shapes{inputs, *this}.has(2).only_dims(4);

        const auto& input   = inputs[0];
        const auto& weights = inputs[1];
        auto batch_size     = input.lens()[0];
        auto input_channels = weights.lens()[1];
        auto kernel_height  = weights.lens()[2];
        auto kernel_width   = weights.lens()[3];

        if(batch_size != 1)
            MIGRAPH_THROW("im2col only support batch_size 1");

        // Output length along spatial dimension i (0 = height, 1 = width). This
        // uses signed arithmetic, so a kernel larger than the padded input clamps
        // to 1 instead of wrapping to a huge std::size_t.
        auto output_len = [&](std::size_t i, std::size_t kernel) {
            const auto in   = static_cast<std::ptrdiff_t>(input.lens()[i + 2]);
            const auto k    = static_cast<std::ptrdiff_t>(kernel);
            const auto pad  = static_cast<std::ptrdiff_t>(padding[i]);
            const auto dil  = static_cast<std::ptrdiff_t>(dilation[i]);
            const auto step = static_cast<std::ptrdiff_t>(stride[i]);
            const auto len  = (in - (1 + dil * (k - 1)) + 2 * pad) / step + 1;
            return static_cast<std::size_t>(std::max<std::ptrdiff_t>(1, len));
        };

        auto output_height = output_len(0, kernel_height);
        auto output_width  = output_len(1, kernel_width);
        auto channels_col  = kernel_height * kernel_width * input_channels;
        return {input.type(), {output_height * output_width, channels_col}};
    }
};

Paul's avatar
Paul committed
165
struct pooling
Paul's avatar
Paul committed
166
{
Paul's avatar
Paul committed
167
    std::string mode                   = "average";
Paul's avatar
Paul committed
168
169
170
    std::array<std::size_t, 2> padding = {{0, 0}};
    std::array<std::size_t, 2> stride  = {{1, 1}};
    std::array<std::size_t, 2> lengths = {{1, 1}};
Paul's avatar
Paul committed
171
    std::string name() const { return "pooling"; }
Scott Thornton's avatar
Scott Thornton committed
172

Paul's avatar
Paul committed
173
174
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
175
        check_shapes{inputs, *this}.has(1).only_dims(4);
Paul's avatar
Paul committed
176

Paul's avatar
Paul committed
177
        const shape& input = inputs.at(0);
Paul's avatar
Paul committed
178
        auto t             = input.type();
Paul's avatar
Paul committed
179

Paul's avatar
Paul committed
180
181
        assert(lengths[0] <= (input.lens()[2] + 2 * padding[0]));
        assert(lengths[1] <= (input.lens()[3] + 2 * padding[1]));
Paul's avatar
Paul committed
182

Scott Thornton's avatar
Scott Thornton committed
183
184
185
186
187
188
        return {t,
                {
                    input.lens()[0],
                    input.lens()[1],
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
189
                        std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
Paul's avatar
Paul committed
190
                                                  static_cast<float>(stride[0]))) +
Scott Thornton's avatar
Scott Thornton committed
191
192
193
                            1)),
                    std::size_t(std::max<std::ptrdiff_t>(
                        1,
Paul's avatar
Paul committed
194
                        std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
Paul's avatar
Paul committed
195
                                                  static_cast<float>(stride[1]))) +
Scott Thornton's avatar
Scott Thornton committed
196
197
                            1)),
                }};
Paul's avatar
Paul committed
198
    }
Paul's avatar
Paul committed
199

Paul's avatar
Paul committed
200
    friend std::ostream& operator<<(std::ostream& os, const pooling& op)
Paul's avatar
Paul committed
201
    {
Paul's avatar
Paul committed
202
203
204
205
206
        os << op.name() << "[";
        os << "padding={" << stream_range(op.padding) << "}, ";
        os << "stride={" << stream_range(op.stride) << "}, ";
        os << "lengths={" << stream_range(op.lengths) << "}";
        os << "]";
Paul's avatar
Paul committed
207
208
        return os;
    }
Paul's avatar
Paul committed
209
210
};

Paul's avatar
Paul committed
211
struct activation
Paul's avatar
Paul committed
212
213
{
    std::string mode;
Paul's avatar
Paul committed
214
    std::string name() const { return "activation"; }
Paul's avatar
Paul committed
215
216
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
217
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
218
219
        return inputs.front();
    }
Paul's avatar
Paul committed
220
    friend std::ostream& operator<<(std::ostream& os, const activation& op)
Paul's avatar
Paul committed
221
    {
Paul's avatar
Paul committed
222
        os << op.name() << ":" << op.mode;
Paul's avatar
Paul committed
223
224
        return os;
    }
Paul's avatar
Paul committed
225
226
};

227
228
229
230
231
232
struct transpose
{
    std::vector<int64_t> dims;
    std::string name() const { return "transpose"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
233
        check_shapes{inputs, *this}.has(1);
234
        auto input         = inputs.at(0);
235
        auto input_lens    = input.lens();
236
237
        auto input_strides = input.strides();
        auto t             = input.type();
Paul's avatar
Paul committed
238
239
        if(dims.size() != input_lens.size())
        {
Paul's avatar
Paul committed
240
            MIGRAPH_THROW("Permutation has wrong number of axes");
241
242
243
        }
        std::vector<int64_t> axes(dims.size());
        std::iota(axes.begin(), axes.end(), 0);
Paul's avatar
Paul committed
244
245
        if(!std::is_permutation(axes.begin(), axes.end(), dims.begin()))
        {
Paul's avatar
Paul committed
246
            MIGRAPH_THROW("Invalid permutation");
247
        }
248
249
        std::vector<size_t> output_lens(input_lens.size());
        std::vector<size_t> output_strides(input_lens.size());
Paul's avatar
Paul committed
250
251
252
        for(int i = 0; i < output_lens.size(); i++)
        {
            output_lens[i]    = input_lens[dims[i]];
253
254
            output_strides[i] = input_strides[dims[i]];
        }
255
        return {t, output_lens, output_strides};
256
    }
Paul's avatar
Paul committed
257
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
258
    {
Paul's avatar
Paul committed
259
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
260
    }
Paul's avatar
Paul committed
261
262
263
264
265
266
267
    friend std::ostream& operator<<(std::ostream& os, const transpose& op)
    {
        os << op.name() << "[";
        os << "dims={" << stream_range(op.dims) << "}";
        os << "]";
        return os;
    }
268
269
};

Paul's avatar
Paul committed
270
struct contiguous
271
272
273
274
{
    std::string name() const { return "contiguous"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
275
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
276
277
278
279
        auto lens = inputs.at(0).lens();
        auto t    = inputs.at(0).type();
        if(lens.size() < 2)
        {
Paul's avatar
Paul committed
280
            MIGRAPH_THROW("Number of dimensions should exceed 1");
281
282
283
284
285
        }
        return {t, lens};
    }
};

286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
/// Concatenates its inputs along `axis`.
///
/// All inputs must share type and rank, `axis` must be a valid dimension index,
/// and every length other than the one at `axis` must agree. The result has the
/// first input's lengths, with the `axis` entry replaced by the sum of the
/// inputs' `axis` lengths.
struct concat
{
    std::size_t axis = 0;
    std::string name() const { return "concat"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        if(inputs.empty())
        {
            MIGRAPH_THROW("Number of input tensors should exceed 0");
        }
        // Equal rank is required before every input is indexed with the first
        // input's dimension indices. Mixed types cannot share one output type.
        check_shapes{inputs, *this}.same_type().same_ndims();

        const auto& first_shape_lens = inputs.front().lens();
        if(axis >= first_shape_lens.size())
        {
            MIGRAPH_THROW("Concat axis is out of range");
        }
        const auto type = inputs.front().type();

        for(std::size_t l = 0; l < first_shape_lens.size(); l++)
        {
            if(l == axis)
                continue;
            // Take shapes by reference; the original lambda copied each one.
            if(!std::all_of(inputs.begin(), inputs.end(), [&](const shape& s) {
                   return s.lens()[l] == first_shape_lens[l];
               }))
            {
                MIGRAPH_THROW("Non-axis dimensions should match");
            }
        }

        std::size_t new_dim_axis = 0;
        for(const auto& input : inputs)
            new_dim_axis += input.lens()[axis];

        std::vector<std::size_t> new_lens = first_shape_lens;
        new_lens[axis]                    = new_dim_axis;
        return {type, new_lens};
    }
};

321
322
323
324
325
326
/// Takes a sub-range of the input along the listed axes without copying.
///
/// For each i, axis axes[i] is restricted to [starts[i], ends[i]). Negative
/// indices count from the end, and indices are clamped to [0, len]. The output
/// keeps the input strides, and compute() offsets into the input buffer.
struct slice
{
    std::vector<int64_t> axes;
    std::vector<int64_t> starts;
    std::vector<int64_t> ends;
    std::string name() const { return "slice"; }

    /// Normalizes `index` for axis `axis`: it is clamped to lens[axis], wrapped
    /// once if negative, and then clamped to 0. Indices below -len used to stay
    /// negative and turn into a huge std::size_t.
    auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
    {
        const auto len = static_cast<int64_t>(lens[axis]);
        int64_t r      = std::min(index, len);
        if(r < 0)
            r += len;
        return static_cast<std::size_t>(std::max<int64_t>(r, 0));
    }

    /// Element offset of the slice's first element within shape `s`.
    auto compute_offset(const shape& s) const
    {
        const std::vector<std::size_t>& lens    = s.lens();
        const std::vector<std::size_t>& strides = s.strides();
        // std::size_t, not `auto offset = 0` (int), which truncated large offsets.
        std::size_t offset = 0;
        if(!axes.empty())
        {
            for(std::size_t i = 0; i < axes.size(); i++)
            {
                auto axis = axes[i];
                offset += fix_index(lens, axis, starts[i]) * strides[axis];
            }
        }
        else
        {
            // Without explicit axes, starts apply to the leading axes in order.
            // This is bounded by starts.size(); the original read starts[axis]
            // for every axis, which is out of bounds when starts is shorter
            // (it is empty whenever compute_shape accepted an empty `axes`).
            const auto n = std::min(starts.size(), lens.size());
            for(std::size_t axis = 0; axis < n; axis++)
            {
                offset += fix_index(lens, axis, starts[axis]) * strides[axis];
            }
        }
        return offset;
    }

    /// Throws if the input count is not 1, if axes/starts/ends differ in size,
    /// or if an axis is outside [0, rank). An empty range (end <= start) gives
    /// a length of 0 instead of an unsigned underflow.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        const auto& input_shape = inputs.at(0);
        auto t                  = input_shape.type();
        const auto& old_lens    = input_shape.lens();
        const auto& old_strides = input_shape.strides();

        if(starts.size() != axes.size() || axes.size() != ends.size())
        {
            MIGRAPH_THROW("inconsistent sizes");
        }

        std::vector<std::size_t> new_lens = old_lens;
        for(std::size_t i = 0; i < axes.size(); i++)
        {
            auto axis = axes[i];
            if(axis < 0 || static_cast<std::size_t>(axis) >= old_lens.size())
            {
                MIGRAPH_THROW("slice axis is out of range");
            }
            const auto first = fix_index(old_lens, axis, starts[i]);
            const auto last  = fix_index(old_lens, axis, ends[i]);
            new_lens[axis]   = last > first ? last - first : 0;
        }
        return shape{t, new_lens, old_strides};
    }

    /// Returns a view whose data pointer is the input data advanced by the
    /// slice offset (scaled by the element size).
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        auto input  = args[0];
        auto offset = compute_offset(input.get_shape()) * output_shape.type_size();
        return {std::move(output_shape), [=] { return input.data() + offset; }};
    }
};

/// Removes length-1 axes; compute() reuses the input buffer.
///
/// With `axes` empty, every length-1 axis is dropped. Otherwise exactly the
/// listed axes are dropped, and each must be in range and have length 1.
struct squeeze
{
    std::vector<int64_t> axes;
    std::string name() const { return "squeeze"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        const auto& input_shape = inputs.at(0);
        auto type               = input_shape.type();
        const auto& old_lens    = input_shape.lens();

        // Range check first: the length test below indexes old_lens by axis.
        if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
               return axis < 0 || static_cast<std::size_t>(axis) >= old_lens.size();
           }))
        {
            MIGRAPH_THROW("squeeze axis is out of range");
        }
        if(std::any_of(axes.begin(), axes.end(), [&](auto axis) {
               return old_lens[static_cast<std::size_t>(axis)] != 1;
           }))
        {
            MIGRAPH_THROW("squeeze axis dimension should be equal to 1");
        }

        std::vector<std::size_t> new_lens;
        if(axes.empty())
        {
            std::copy_if(old_lens.begin(),
                         old_lens.end(),
                         std::back_inserter(new_lens),
                         [](auto len) { return len != 1; });
        }
        else
        {
            for(std::size_t i = 0; i < old_lens.size(); i++)
            {
                if(std::find(axes.begin(), axes.end(), static_cast<int64_t>(i)) == axes.end())
                {
                    new_lens.push_back(old_lens[i]);
                }
            }
        }
        return shape{type, new_lens};
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
};

/// Inserts length-1 axes at the positions listed in `axes` (indices into the
/// result). compute() reuses the input buffer.
struct unsqueeze
{
    std::vector<int64_t> axes;
    std::string name() const { return "unsqueeze"; }

    /// Throws unless every axis is in [0, rank + axes.size()) and appears
    /// once. Otherwise fewer slots become 1 than expected, and the fill loop
    /// reads past the end of the input lengths.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);
        const auto& input_shape = inputs.at(0);
        auto type               = input_shape.type();
        const auto& old_lens    = input_shape.lens();
        std::size_t new_size    = old_lens.size() + axes.size();

        for(auto axis : axes)
        {
            if(axis < 0 || static_cast<std::size_t>(axis) >= new_size)
                MIGRAPH_THROW("unsqueeze axis is out of range");
            if(std::count(axes.begin(), axes.end(), axis) != 1)
                MIGRAPH_THROW("unsqueeze axes must be unique");
        }

        std::vector<std::size_t> new_lens(new_size);
        std::size_t p = 0;
        for(std::size_t i = 0; i < new_size; i++)
        {
            if(std::find(axes.begin(), axes.end(), static_cast<int64_t>(i)) != axes.end())
                new_lens[i] = 1;
            else
                new_lens[i] = old_lens[p++];
        }
        return shape{type, new_lens};
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
        return {std::move(output_shape), std::move(args.front().data)};
    }
};

Paul's avatar
Paul committed
466
467
468
struct reshape
{
    std::vector<int64_t> dims;
Paul's avatar
Paul committed
469
    std::string name() const { return "reshape"; }
Paul's avatar
Paul committed
470
471
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
472
        check_shapes{inputs, *this}.has(1);
Paul's avatar
Paul committed
473
474
        auto&& idims = inputs.front().lens();
        std::vector<std::size_t> rdims(dims.begin(), dims.end());
475
476
477
        auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
        if(n_neg_dims > 1)
            MIGRAPH_THROW("Dimensions for reshape can only have one -1 dim");
Paul's avatar
Paul committed
478
        for(std::size_t i = 0; i < dims.size(); i++)
Paul's avatar
Paul committed
479
480
481
482
        {
            if(dims[i] == 0)
                rdims[i] = idims[i];
        }
483
484
485
486
487
488
489
490
491
492
493
        if(n_neg_dims > 0)
        {
            size_t missing_dim =
                -inputs.front().elements() /
                std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
            for(std::size_t i = 0; i < rdims.size(); i++)
            {
                if(dims[i] == -1)
                    rdims[i] = missing_dim;
            }
        }
Paul's avatar
Paul committed
494
495
496
        if(dims.back() == -1)
        {
            rdims.pop_back();
Paul's avatar
Paul committed
497
            std::copy(idims.begin() + rdims.size(), idims.end(), std::back_inserter(rdims));
Paul's avatar
Paul committed
498
        }
Scott Thornton's avatar
Scott Thornton committed
499
        shape s{inputs.front().type(), rdims};
Paul's avatar
Paul committed
500
        if(s.elements() != inputs.front().elements())
Paul's avatar
Paul committed
501
            MIGRAPH_THROW("Wrong number of elements for reshape");
Scott Thornton's avatar
Scott Thornton committed
502
        return s;
Paul's avatar
Paul committed
503
    }
Paul's avatar
Paul committed
504
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Paul's avatar
Paul committed
505
    {
Paul's avatar
Paul committed
506
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
507
    }
Paul's avatar
Paul committed
508
    friend std::ostream& operator<<(std::ostream& os, const reshape& op)
Paul's avatar
Paul committed
509
    {
Paul's avatar
Paul committed
510
        os << op.name() << "[";
Paul's avatar
Paul committed
511
        os << "dims={" << stream_range(op.dims) << "}";
Paul's avatar
Paul committed
512
        os << "]";
Paul's avatar
Paul committed
513
514
        return os;
    }
Paul's avatar
Paul committed
515
516
};

517
518
struct gemm
{
Paul's avatar
Paul committed
519
    float alpha = 1.0;
Paul's avatar
Paul committed
520
    float beta  = 0.0;
521
    std::string name() const { return "gemm"; }
522
523
    shape compute_shape(std::vector<shape> inputs) const
    {
Paul's avatar
Paul committed
524
        check_shapes{inputs, *this}.has(2).same_type();
525
526
        const shape& a = inputs.at(0);
        const shape& b = inputs.at(1);
Scott Thornton's avatar
Scott Thornton committed
527
        auto t         = a.type();
528

529
        if(a.lens()[1] != b.lens()[0])
Paul's avatar
Paul committed
530
531
            MIGRAPH_THROW("Inner dimensions do not match: {" + to_string_range(a.lens()) + "} x {" +
                          to_string_range(b.lens()) + "}");
Scott Thornton's avatar
Scott Thornton committed
532
        return {t, {a.lens()[0], b.lens()[1]}};
533
    }
534
535

    friend std::ostream& operator<<(std::ostream& os, const gemm& op)
536
537
    {
        os << op.name() << "[";
538
        os << "]";
Scott Thornton's avatar
Scott Thornton committed
539
        return os;
540
541
542
    }
};

543
struct unary
Scott Thornton's avatar
Scott Thornton committed
544
{
545
546
    shape compute_shape(std::vector<shape> inputs) const
    {
547
548
        check_shapes{inputs}.has(1);
        return inputs.at(0);
549
    }
Scott Thornton's avatar
Scott Thornton committed
550
551
};

552
553
// Elementwise unary operators. Each one contributes only its name; shape
// inference (one input, same output shape) is inherited from `unary`.

/// Operator named "identity".
struct identity : unary
{
    std::string name() const
    {
        return "identity";
    }
};

/// Operator named "abs".
struct abs : unary
{
    std::string name() const
    {
        return "abs";
    }
};

/// Operator named "exp".
struct exp : unary
{
    std::string name() const
    {
        return "exp";
    }
};

/// Operator named "sin".
struct sin : unary
{
    std::string name() const
    {
        return "sin";
    }
};

/// Operator named "cos".
struct cos : unary
{
    std::string name() const
    {
        return "cos";
    }
};

/// Operator named "tan".
struct tan : unary
{
    std::string name() const
    {
        return "tan";
    }
};

/// Operator named "asin".
struct asin : unary
{
    std::string name() const
    {
        return "asin";
    }
};

/// Operator named "acos".
struct acos : unary
{
    std::string name() const
    {
        return "acos";
    }
};

/// Operator named "atan".
struct atan : unary
{
    std::string name() const
    {
        return "atan";
    }
};

/// Operator named "tanh".
struct tanh : unary
{
    std::string name() const
    {
        return "tanh";
    }
};

/// Operator named "sigmoid".
struct sigmoid : unary
{
    std::string name() const
    {
        return "sigmoid";
    }
};

/// Operator named "neg".
struct neg : unary
{
    std::string name() const
    {
        return "neg";
    }
};

Paul's avatar
Paul committed
612
613
614
615
616
617
618
619
620
621
/// Softmax: takes one 4-D input and produces an output of the same shape.
struct softmax
{
    std::string name() const { return "softmax"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1).only_dims(4);
        const shape& result = inputs.at(0);
        return result;
    }
};

622
struct flatten
Scott Thornton's avatar
Scott Thornton committed
623
{
Paul's avatar
Paul committed
624
    uint64_t axis = 0;
Scott Thornton's avatar
Scott Thornton committed
625
    std::string name() const { return "flatten"; }
Paul's avatar
Paul committed
626
627
628
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs}.has(1);
Paul's avatar
Paul committed
629
630
        auto&& lens = inputs.front().lens();

Paul's avatar
Paul committed
631
        if(axis > lens.size())
Paul's avatar
Paul committed
632
        {
Paul's avatar
Paul committed
633
            MIGRAPH_THROW("axis for flatten must be less than tensor rank");
Paul's avatar
Paul committed
634
        }
Paul's avatar
Paul committed
635
636
637
638
        auto x =
            std::accumulate(lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
        auto y =
            std::accumulate(lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
639
        return {inputs.at(0).type(), {x, y}};
Paul's avatar
Paul committed
640
641
642
    }
    argument compute(context&, shape output_shape, std::vector<argument> args) const
    {
Paul's avatar
Paul committed
643
        return {std::move(output_shape), std::move(args.front().data)};
Paul's avatar
Paul committed
644
    }
Paul's avatar
Paul committed
645
646
647
648
649
650
651
    friend std::ostream& operator<<(std::ostream& os, const flatten& op)
    {
        os << op.name() << "[";
        os << "axis=" << op.axis;
        os << "]";
        return os;
    }
Scott Thornton's avatar
Scott Thornton committed
652
};
653
654
655
struct broadcast
{
    uint64_t axis = 0;
Scott Thornton's avatar
Scott Thornton committed
656
    shape broadcast_shape;
657
658
659
    std::string name() const { return "broadcast"; }
    shape compute_shape(std::vector<shape> inputs) const
    {
Scott Thornton's avatar
Scott Thornton committed
660
661
        auto t     = inputs.at(0).type();
        auto input = inputs.at(0);
Paul's avatar
Paul committed
662

Scott Thornton's avatar
Scott Thornton committed
663
        std::vector<size_t> bcast_strides(broadcast_shape.lens().size(), 0);
664

Scott Thornton's avatar
Scott Thornton committed
665
666
667
        if(std::all_of(broadcast_shape.lens().cbegin(), broadcast_shape.lens().cend(), [&](auto x) {
               return x == 1;
           }))
668
        {
Scott Thornton's avatar
Scott Thornton committed
669
            if(axis != 0)
Paul's avatar
Paul committed
670
                MIGRAPH_THROW("when broadcasting tensor of size 1, axis should be 0");
Scott Thornton's avatar
Scott Thornton committed
671
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
672
673
674
        }
        else
        {
Scott Thornton's avatar
Scott Thornton committed
675
            assert(broadcast_shape.lens().size() - axis >= input.lens().size());
Scott Thornton's avatar
Scott Thornton committed
676
677
            if(!std::equal(
                   input.lens().begin(), input.lens().end(), broadcast_shape.lens().begin() + axis))
Paul's avatar
Paul committed
678
                MIGRAPH_THROW("when broadcasting success sizes must match");
Paul's avatar
Paul committed
679
            std::copy(input.strides().begin(), input.strides().end(), bcast_strides.begin() + axis);
Scott Thornton's avatar
Scott Thornton committed
680
            return {t, broadcast_shape.lens(), std::move(bcast_strides)};
681
682
        }
    }
Paul's avatar
Paul committed
683
    argument compute(context&, shape output_shape, std::vector<argument> args) const
Scott Thornton's avatar
Scott Thornton committed
684
    {
Scott Thornton's avatar
Scott Thornton committed
685
        return {std::move(output_shape), std::move(args.at(0).data)};
Scott Thornton's avatar
Scott Thornton committed
686
    }
Paul's avatar
Paul committed
687
688
689
690
691
692
693
    friend std::ostream& operator<<(std::ostream& os, const broadcast& op)
    {
        os << op.name() << "[";
        os << "axis=" << op.axis;
        os << "]";
        return os;
    }
694
695
};

696
struct binary
Scott Thornton's avatar
Scott Thornton committed
697
{
698
    uint64_t broadcast = 0;
699
700
    shape compute_shape(std::vector<shape> inputs) const
    {
701
702
        check_shapes{inputs}.has(2).same_type().same_dims();
        return inputs.at(0);
703
    }
Scott Thornton's avatar
Scott Thornton committed
704
705
};

706
707
708
709
710
711
// Elementwise binary arithmetic operators. Each one contributes only its name;
// shape inference (two inputs of equal type and dimensions) is inherited from
// `binary`.

/// Operator named "add".
struct add : binary
{
    std::string name() const
    {
        return "add";
    }
};

/// Operator named "sub".
struct sub : binary
{
    std::string name() const
    {
        return "sub";
    }
};

/// Operator named "mul".
struct mul : binary
{
    std::string name() const
    {
        return "mul";
    }
};

/// Operator named "div".
struct div : binary
{
    std::string name() const
    {
        return "div";
    }
};

Paul's avatar
Paul committed
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
/// Views part of an existing buffer. The output has the fixed shape `s`, and
/// its data is args[0].data() advanced by `offset`.
/// NOTE(review): `offset` is presumably a byte offset (the pointer type
/// returned by data() is not visible here) — confirm.
struct load
{
    shape s;
    std::size_t offset = 0;
    std::string name() const { return "load"; }

    /// Requires exactly one input and always returns `s`.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs}.has(1);
        return s;
    }
    argument compute(context&, const shape&, const std::vector<argument>& args) const
    {
        auto base = args[0].data();
        return {s, base + offset};
    }
};

Paul's avatar
Paul committed
742
struct outline
Scott Thornton's avatar
Scott Thornton committed
743
{
Paul's avatar
Paul committed
744
745
    shape s;
    std::string name() const { return "outline"; }
Paul's avatar
Paul committed
746
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
747
    {
Paul's avatar
Paul committed
748
        check_shapes{inputs, *this}.has(0);
Paul's avatar
Paul committed
749
750
        return s;
    }
Paul's avatar
Paul committed
751
752
753
754
    argument compute(context&, const shape&, const std::vector<argument>&) const
    {
        return {s, nullptr};
    }
Scott Thornton's avatar
Scott Thornton committed
755
756
};

757
} // namespace op
Paul's avatar
Paul committed
758
} // namespace migraph
Paul's avatar
Paul committed
759
760

#endif