operation.hpp 49.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
Paul's avatar
Paul committed
24
25
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
Paul's avatar
Paul committed
26

Paul's avatar
Paul committed
27
#include <cassert>
Paul's avatar
Paul committed
28
#include <string>
Paul's avatar
Paul committed
29
#include <functional>
Paul's avatar
Paul committed
30
31
32
#include <memory>
#include <type_traits>
#include <utility>
Shucai Xiao's avatar
Shucai Xiao committed
33
#include <unordered_map>
Paul's avatar
Paul committed
34
#include <migraphx/reflect.hpp>
35
36
#include <migraphx/dyn_output.hpp>
#include <migraphx/functional.hpp>
Paul's avatar
Paul committed
37
#include <migraphx/streamutils.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
38
#include <migraphx/normalize_attributes.hpp>
Paul's avatar
Paul committed
39
#include <migraphx/argument.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
40
#include <migraphx/module_ref.hpp>
41
#include <migraphx/serialize.hpp>
Paul's avatar
Paul committed
42
#include <migraphx/auto_any_cast.hpp>
43
#include <migraphx/lifetime.hpp>
Paul's avatar
Paul committed
44
45
46
#include <migraphx/config.hpp>

namespace migraphx {
Paul's avatar
Paul committed
47
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
48

Paul's avatar
Paul committed
49
50
struct context;

Paul's avatar
Paul committed
51
52
#ifdef DOXYGEN

/// The operation interface represents an action an instruction will perform. All
/// operation classes must be CopyConstructible.
struct operation
{
    /// A unique name identifying the operation
    std::string name() const;

    /// An optional method that can be used to finalize the operator before running. It
    /// receives the compilation context together with the instruction's output shape and
    /// input shapes.
    void finalize(context& ctx, const shape& output, const std::vector<shape>& input);

    /// This is used to compute the resulting shape from an operation. If an
    /// operation cannot be run with the given input shapes, then it should throw an
    /// exception.
    shape compute_shape(const std::vector<shape>& input) const;

    /**
     * @brief This performs the operation's computation.
     *
     * This method can be optional when the operation is only used as a placeholder to be lowered
     * later on.
     *
     * @param ctx This is the context created by the `target` during compilation. Implementations
     * can use the target's `context` class rather than the `context` interface class.
     * @param output Equivalent to running `compute_shape` with each `shape` of the `argument`.
     * For a fixed shape, the returned argument will have the same shape as `output`.
     * For a dynamic shape, the returned `argument` will be a fixed shape within the bounds
     * set in the dynamic shape `output`.
     * @param input This is the `argument` result from the previous instruction's computation.
     * @return Return an `argument` of the result computation. The `shape` of `argument` should be
     * the same as the `output` shape.
     */
    argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;

    /// An optional method to return which argument the output will alias. If
    /// there is no aliased output then -1 can be returned.
    std::ptrdiff_t output_alias(const std::vector<shape>& input) const;

    /// An optional stream operator to print the operation. When this is not
    /// implemented, the operation's name is printed, followed by its reflected fields
    /// as `[field=value,...]` (just the name when it has no reflected fields).
    friend std::ostream& operator<<(std::ostream& os, const operation& op);
};

/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);

/// Returns true if operation needs normalization before running compute
bool need_normalization(const operation& x);

/// Returns true if the operation has a finalize method
bool has_finalize(const operation& x);

#else

99
100
101
namespace detail {

namespace operation_operators {
Paul's avatar
Paul committed
102
103
104
105

template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{
Paul's avatar
Paul committed
106
107
    os << x.name();
    char delim = '[';
Paul's avatar
Paul committed
108
    reflect_each(x, [&](auto&& y, auto name) {
Paul's avatar
Paul committed
109
        os << delim;
Paul's avatar
Paul committed
110
111
        os << name << "=";
        stream_write_value(os, y);
Paul's avatar
Paul committed
112
113
114
115
116
        delim = ',';
    });
    if(delim == ',')
        os << "]";
    return os;
Paul's avatar
Paul committed
117
118
}

Paul's avatar
Paul committed
119
120
121
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
122
123
    static_assert(is_reflectable<T>{} or sizeof(T) <= 1,
                  "Missing equality operator or reflect method.");
Paul's avatar
Paul committed
124
125
126
127
128
129
    if(x.name() != y.name())
        return false;
    const auto& yy = any_cast<T>(y);
    return reflect_tie(x) == reflect_tie(yy);
}

130
} // namespace operation_operators
Paul's avatar
Paul committed
131

Shucai Xiao's avatar
Shucai Xiao committed
132
template <class T>
133
134
135
136
137
138
139
140
auto compute_shape_op(rank<3>, const T& x, const std::vector<shape>& inputs)
    -> decltype(x.compute_shape(inputs))
{
    return x.compute_shape(inputs);
}

template <class T>
auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
Shucai Xiao's avatar
Shucai Xiao committed
141
    -> decltype(x.normalize_compute_shape(inputs))
Shucai Xiao's avatar
Shucai Xiao committed
142
{
143
144
    if(inputs.empty())
        MIGRAPHX_THROW("At least one input is required for " + x.name());
Charlie Lin's avatar
Charlie Lin committed
145
    dependent_type<operation, T> y = x;
146
    normalize_attributes(y, inputs[0].max_lens());
Shucai Xiao's avatar
Shucai Xiao committed
147
148
149
    return any_cast<T>(y).normalize_compute_shape(inputs);
}

150
template <class T>
151
auto compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
152
153
154
155
156
    -> decltype(x.compute_shape(inputs, {}))
{
    return x.compute_shape(inputs, {});
}

Shucai Xiao's avatar
Shucai Xiao committed
157
template <class T>
158
shape compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
Shucai Xiao's avatar
Shucai Xiao committed
159
160
161
162
163
164
{
    std::string name = x.name();
    MIGRAPHX_THROW("Shape not computable: " + name);
}

template <class T>
165
shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
Shucai Xiao's avatar
Shucai Xiao committed
166
{
167
    return compute_shape_op(rank<3>{}, x, inputs);
Shucai Xiao's avatar
Shucai Xiao committed
168
169
170
}

template <class T>
171
172
173
174
auto mod_compute_shape_op(rank<1>,
                          const T& x,
                          const std::vector<shape>& inputs,
                          const std::vector<module_ref>& mod_args)
Shucai Xiao's avatar
Shucai Xiao committed
175
176
177
178
179
180
    -> decltype(x.compute_shape(inputs, mod_args))
{
    return x.compute_shape(inputs, mod_args);
}

template <class T>
181
182
183
184
shape mod_compute_shape_op(rank<0>,
                           const T& x,
                           const std::vector<shape>& inputs,
                           const std::vector<module_ref>& mod_args)
Shucai Xiao's avatar
Shucai Xiao committed
185
{
186
187
    if(mod_args.empty())
        return compute_shape_op(x, inputs);
Shucai Xiao's avatar
Shucai Xiao committed
188
189
190
191
192
    std::string name = x.name();
    MIGRAPHX_THROW("Shape not computable: " + name);
}

template <class T>
193
194
195
shape mod_compute_shape_op(const T& x,
                           const std::vector<shape>& inputs,
                           const std::vector<module_ref>& mod_args)
Shucai Xiao's avatar
Shucai Xiao committed
196
{
197
    return mod_compute_shape_op(rank<1>{}, x, inputs, mod_args);
Shucai Xiao's avatar
Shucai Xiao committed
198
199
}

Paul's avatar
Paul committed
200
template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
201
auto compute_op(rank<1>,
Paul's avatar
Paul committed
202
203
204
205
                const T& x,
                context& ctx,
                const shape& output_shape,
                const std::vector<argument>& input)
206
207
208
    -> decltype(x.compute(auto_any_cast(ctx),
                          make_compute_output_shape(pack(x, output_shape, input)),
                          input))
Paul's avatar
Paul committed
209
{
210
211
    return x.compute(
        auto_any_cast(ctx), make_compute_output_shape(pack(x, output_shape, input)), input);
Paul's avatar
Paul committed
212
213
214
215
216
}

template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
Paul's avatar
Paul committed
217
    std::string name = x.name();
Paul's avatar
Paul committed
218
    MIGRAPHX_THROW("Not computable: " + name);
Paul's avatar
Paul committed
219
220
}

Paul's avatar
Paul committed
221
template <class T>
Paul's avatar
Paul committed
222
223
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
224
{
Shucai Xiao's avatar
Shucai Xiao committed
225
    return compute_op(rank<1>{}, x, ctx, output_shape, input);
Paul's avatar
Paul committed
226
227
228
}

template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
229
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
230
    -> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input))
Paul's avatar
Paul committed
231
{
232
    return x.compute(make_compute_output_shape(pack(x, output_shape, input)), input);
Paul's avatar
Paul committed
233
234
235
}

template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
236
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
Paul's avatar
Paul committed
237
238
{
    std::string name = x.name();
Shucai Xiao's avatar
Shucai Xiao committed
239
    MIGRAPHX_THROW("Not computable: " + name);
Paul's avatar
Paul committed
240
241
242
}

template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
243
244
245
246
247
248
249
250
251
252
253
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
    return compute_op(rank<1>{}, x, output_shape, input);
}

template <class T, class F>
auto compute_op(rank<1>,
                const T& x,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>& module_args,
254
255
256
                F f)
    -> decltype(
        x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
Shucai Xiao's avatar
Shucai Xiao committed
257
{
258
    return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
Shucai Xiao's avatar
Shucai Xiao committed
259
260
261
262
263
264
265
266
267
}

template <class T, class F>
argument compute_op(rank<0>,
                    const T& x,
                    const shape&,
                    const std::vector<argument>&,
                    const std::vector<module_ref>&,
                    F)
Paul's avatar
Paul committed
268
269
270
271
272
{
    std::string name = x.name();
    MIGRAPHX_THROW("Not computable: " + name);
}

Shucai Xiao's avatar
Shucai Xiao committed
273
274
275
276
277
278
template <class T, class F>
argument compute_op(const T& x,
                    const shape& output,
                    const std::vector<argument>& inputs,
                    const std::vector<module_ref>& module_args,
                    F f)
Paul's avatar
Paul committed
279
{
Shucai Xiao's avatar
Shucai Xiao committed
280
    return compute_op(rank<1>{}, x, output, inputs, module_args, f);
Paul's avatar
Paul committed
281
282
}

Shucai Xiao's avatar
Shucai Xiao committed
283
284
285
286
287
288
289
template <class T, class F>
auto compute_op(rank<4>,
                const T& x,
                context& ctx,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>& module_args,
290
291
292
293
294
                F f) -> decltype(x.compute(auto_any_cast(ctx),
                                           make_compute_output_shape(pack(x, output, inputs)),
                                           inputs,
                                           module_args,
                                           f))
Shucai Xiao's avatar
Shucai Xiao committed
295
{
296
297
298
299
300
    return x.compute(auto_any_cast(ctx),
                     make_compute_output_shape(pack(x, output, inputs)),
                     inputs,
                     module_args,
                     f);
Shucai Xiao's avatar
Shucai Xiao committed
301
302
}

Shucai Xiao's avatar
Shucai Xiao committed
303
template <class T, class F>
Shucai Xiao's avatar
Shucai Xiao committed
304
auto compute_op(rank<3>,
Shucai Xiao's avatar
Shucai Xiao committed
305
                const T& x,
Shucai Xiao's avatar
Shucai Xiao committed
306
307
                context&,
                const shape& output,
Shucai Xiao's avatar
Shucai Xiao committed
308
309
                const std::vector<argument>& inputs,
                const std::vector<module_ref>& module_args,
310
311
312
                F f)
    -> decltype(
        x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
Shucai Xiao's avatar
Shucai Xiao committed
313
{
314
    return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
Shucai Xiao's avatar
Shucai Xiao committed
315
316
317
}

template <class T, class F>
Shucai Xiao's avatar
Shucai Xiao committed
318
319
320
321
322
323
auto compute_op(rank<2>,
                const T& x,
                context&,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>&,
324
325
                F)
    -> decltype(x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs))
Shucai Xiao's avatar
Shucai Xiao committed
326
{
327
    return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs);
Shucai Xiao's avatar
Shucai Xiao committed
328
329
330
331
332
333
334
335
336
}

template <class T, class F>
auto compute_op(rank<1>,
                const T& x,
                context& ctx,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>&,
337
338
339
                F) -> decltype(x.compute(auto_any_cast(ctx),
                                         make_compute_output_shape(pack(x, output, inputs)),
                                         inputs))
Shucai Xiao's avatar
Shucai Xiao committed
340
{
341
342
    return x.compute(
        auto_any_cast(ctx), make_compute_output_shape(pack(x, output, inputs)), inputs);
Shucai Xiao's avatar
Shucai Xiao committed
343
344
345
346
347
348
349
350
351
352
}

template <class T, class F>
argument compute_op(rank<0>,
                    const T& x,
                    context&,
                    const shape&,
                    const std::vector<argument>&,
                    const std::vector<module_ref>&,
                    F)
Shucai Xiao's avatar
Shucai Xiao committed
353
354
355
356
357
358
359
{
    std::string name = x.name();
    MIGRAPHX_THROW("Not computable: " + name);
}

template <class T, class F>
argument compute_op(const T& x,
Shucai Xiao's avatar
Shucai Xiao committed
360
361
                    context& ctx,
                    const shape& output,
Shucai Xiao's avatar
Shucai Xiao committed
362
363
364
365
                    const std::vector<argument>& inputs,
                    const std::vector<module_ref>& module_args,
                    F f)
{
Shucai Xiao's avatar
Shucai Xiao committed
366
    return compute_op(rank<4>{}, x, ctx, output, inputs, module_args, f);
Shucai Xiao's avatar
Shucai Xiao committed
367
368
}

Paul's avatar
Paul committed
369
370
371
372
373
template <class T>
auto is_context_free_op(rank<1>,
                        const T& x,
                        const shape& output_shape,
                        const std::vector<argument>& input)
374
375
    -> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input),
                std::true_type{});
Paul's avatar
Paul committed
376
377
378
379
380
381
382
383
384
385

template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
    -> std::false_type;

template <class T>
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
    rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
{
    return {};
Paul's avatar
Paul committed
386
387
}

Shucai Xiao's avatar
Shucai Xiao committed
388
389
390
391
392
393
394
395
396
397
398
399
400
401
// Detection overloads (declared only, used in decltype below). Rank 1 is viable when T has
// `normalize_compute_shape(inputs)`.
template <class T>
auto need_normalization_op(rank<1>, const T& op, const std::vector<shape>& in_shapes)
    -> decltype(op.normalize_compute_shape(in_shapes), std::true_type{});

template <class T>
auto need_normalization_op(rank<0>, const T&, const std::vector<shape>&) -> std::false_type;

/// Yields std::true_type when T defines normalize_compute_shape, std::false_type otherwise.
template <class T>
auto need_normalization_op(const T& x)
    -> decltype(need_normalization_op(rank<1>{}, x, std::declval<std::vector<shape>>()))
{
    return {};
}

Paul's avatar
Paul committed
402
template <class T>
403
std::ptrdiff_t output_alias_op(const T&, const std::vector<shape>&)
Paul's avatar
Paul committed
404
405
406
407
{
    return -1;
}

Paul's avatar
Paul committed
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
template <class T>
auto finalize_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
{
    x.finalize(auto_any_cast(ctx), output_shape, input);
}

template <class T>
void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
}

template <class T>
void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    finalize_op(rank<1>{}, x, ctx, output_shape, input);
}

template <class T>
auto has_finalize_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});

template <class T>
auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
    -> std::false_type;

template <class T>
auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
                                                           std::declval<T&>(),
                                                           std::declval<context&>(),
                                                           std::declval<const shape&>(),
                                                           std::declval<std::vector<shape>>()))
{
    return {};
}

446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
/// Rank 1: T has `compile(ctx, output, inputs)`; call it with the context converted through
/// `auto_any_cast`. `x` is a non-const reference because compiling may update the operation.
template <class T>
auto compile_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(x.compile(auto_any_cast(ctx), output_shape, input))
{
    return x.compile(auto_any_cast(ctx), output_shape, input);
}

/// Rank 0: T has no usable compile member; the result is an empty object.
template <class T>
value compile_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
    return value::object{};
}

/// Compiles a mutable operation.
///
/// Like `finalize_op`, this takes `T&` so that a non-const `compile` member is found.
/// Previously only the `const T&` entry below existed. That instantiated the rank-1 overload
/// with a const T, so a non-const `compile` was dropped by SFINAE and silently replaced by the
/// empty-object fallback.
template <class T>
value compile_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    return compile_op(rank<1>{}, x, ctx, output_shape, input);
}

/// Compiles a const operation (or a temporary).
///
/// Only a const-qualified `compile` member can be used here. Otherwise the result is an empty
/// object. Non-const lvalues are routed to the `T&` overload above, which is the better match.
template <class T>
value compile_op(const T& x,
                 context& ctx,
                 const shape& output_shape,
                 const std::vector<shape>& input)
{
    return compile_op(rank<1>{}, x, ctx, output_shape, input);
}

469
470
471
472
473
474
/// Default `attributes`: an operation reports no attributes (an empty object).
template <class T>
value attributes_op(const T&)
{
    value no_attributes = value::object{};
    return no_attributes;
}

475
476
477
478
479
480
481
482
483
template <class T>
value to_value_op(const T& x)
{
    return migraphx::to_value(x);
}

template <class T>
void from_value_op(T& x, const value& v)
{
484
485
    if(not(v.is_object() or (v.empty() and v.is_array())))
        MIGRAPHX_THROW("Value is not an object");
486
487
488
    return migraphx::from_value(v, x);
}

489
template <class T>
490
lifetime get_lifetime_op(const T&)
491
{
492
    return lifetime::local;
493
494
}

495
496
} // namespace detail

497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
#ifdef TYPE_ERASED_DECLARATION

// Type-erased interface for `operation`: a declaration-only summary of the members exposed by
// the type-erased wrapper. This struct is compiled only when TYPE_ERASED_DECLARATION is
// defined; the wrapper's implementation is in the #else branch below. Members marked
// "(optional)" are presumably ones a concrete operation type may omit.
struct operation
{
    // A unique name identifying the operation.
    std::string name() const;
    // (optional) True if the operation does not require a context to run compute.
    bool is_context_free() const;
    // (optional) True if the operation needs normalization before running compute.
    bool need_normalization() const;
    // (optional) True if the operation has a finalize method.
    bool has_finalize() const;
    // (optional) The operation's lifetime. detail::get_lifetime_op returns lifetime::local
    // (presumably the fallback when the operation does not define this - confirm).
    lifetime get_lifetime() const;
    // (optional) Index of the input the output aliases, or -1 if there is no aliased output;
    // detail::output_alias_op returns -1.
    std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
    // (optional) Called with the context and the output/input shapes and returns a value;
    // detail::compile_op returns an empty object when the operation has no compile member.
    value compile(context& ctx, const shape& output, const std::vector<shape>& input);
    // (optional) Finalizes the operator before running, given the context and the
    // output/input shapes.
    void finalize(context& ctx, const shape& output, const std::vector<shape>& input);
    // (optional) Computes the output shape from the input shapes; should throw if the
    // operation cannot run with them.
    shape compute_shape(const std::vector<shape>& input) const;
    // (optional) Shape computation for operations that take module arguments.
    shape compute_shape(const std::vector<shape>& inputs,
                        const std::vector<module_ref>& mod_args) const;
    // (optional) Performs the computation using the context created by the target.
    argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
    // (optional) Performs the computation without a context; detail::is_context_free_op
    // detects whether an operation provides this form.
    argument compute(const shape& output, const std::vector<argument>& input) const;
    // (optional) Computation for operations with module arguments. `run` takes a module and a
    // name -> argument map and returns arguments (presumably by evaluating that module).
    argument compute(const shape& output,
                     const std::vector<argument>& input,
                     const std::vector<module_ref>& module_args,
                     std::function<std::vector<argument>(
                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
    // (optional) Same as the previous overload, additionally receiving the context.
    argument compute(context& ctx,
                     const shape& output,
                     const std::vector<argument>& input,
                     const std::vector<module_ref>& module_args,
                     std::function<std::vector<argument>(
                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
    // (optional) Serializes the operation; detail::to_value_op uses migraphx::to_value.
    value to_value() const;
    // (optional) Restores the operation from `v`; detail::from_value_op throws unless `v` is
    // an object or an empty array.
    void from_value(const value& v);
    // (optional) Extra attributes of the operation; detail::attributes_op returns an empty
    // object.
    value attributes() const;
    // Prints the operation to a stream.
    friend std::ostream& operator<<(std::ostream& os, const operation& op);
    // Compares two operations for equality.
    friend bool operator==(const operation& x, const operation& y);
};

#else
Paul's avatar
Paul committed
553

Paul's avatar
Paul committed
554
struct operation
Paul's avatar
Paul committed
555
{
Paul's avatar
Paul committed
556
    // Constructors
Paul's avatar
Paul committed
557
    operation() = default;
Paul's avatar
Paul committed
558

Paul's avatar
Paul committed
559
    template <typename PrivateDetailTypeErasedT>
Paul's avatar
Paul committed
560
    operation(PrivateDetailTypeErasedT value)
Paul's avatar
Paul committed
561
562
563
564
        : private_detail_te_handle_mem_var(
              std::make_shared<private_detail_te_handle_type<
                  typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
                  std::forward<PrivateDetailTypeErasedT>(value)))
Paul's avatar
Paul committed
565
566
567
568
    {
    }

    // Assignment
Paul's avatar
Paul committed
569
    template <typename PrivateDetailTypeErasedT>
Paul's avatar
Paul committed
570
    operation& operator=(PrivateDetailTypeErasedT value)
Paul's avatar
Paul committed
571
    {
Paul Fultz II's avatar
Paul Fultz II committed
572
573
574
575
576
577
578
579
580
581
582
        using std::swap;
        auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
        if(derived and private_detail_te_handle_mem_var.unique())
        {
            *derived = std::forward<PrivateDetailTypeErasedT>(value);
        }
        else
        {
            operation rhs(value);
            swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
        }
Paul's avatar
Paul committed
583
584
585
        return *this;
    }

Paul's avatar
Paul committed
586
587
588
589
    // Cast
    template <typename PrivateDetailTypeErasedT>
    PrivateDetailTypeErasedT* any_cast()
    {
Paul Fultz II's avatar
Paul Fultz II committed
590
        return this->type_id() == typeid(PrivateDetailTypeErasedT)
Paul's avatar
Paul committed
591
592
593
594
595
596
597
598
599
600
                   ? std::addressof(static_cast<private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    template <typename PrivateDetailTypeErasedT>
    const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
    {
Paul Fultz II's avatar
Paul Fultz II committed
601
        return this->type_id() == typeid(PrivateDetailTypeErasedT)
Paul's avatar
Paul committed
602
603
604
605
606
607
608
                   ? std::addressof(static_cast<const private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

Paul's avatar
Paul committed
609
610
611
612
613
614
615
616
    const std::type_info& type_id() const
    {
        if(private_detail_te_handle_empty())
            return typeid(std::nullptr_t);
        else
            return private_detail_te_get_handle().type();
    }

Paul's avatar
Paul committed
617
618
    std::string name() const
    {
Paul's avatar
Paul committed
619
620
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().name();
Paul's avatar
Paul committed
621
622
    }

Paul's avatar
Paul committed
623
624
625
626
627
628
    bool is_context_free() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().is_context_free();
    }

Shucai Xiao's avatar
Shucai Xiao committed
629
630
631
632
633
634
    bool need_normalization() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().need_normalization();
    }

Paul's avatar
Paul committed
635
636
637
638
639
640
    bool has_finalize() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().has_finalize();
    }

641
    lifetime get_lifetime() const
642
643
    {
        assert((*this).private_detail_te_handle_mem_var);
644
        return (*this).private_detail_te_get_handle().get_lifetime();
645
646
    }

Paul's avatar
Paul committed
647
    std::ptrdiff_t output_alias(const std::vector<shape>& input) const
Paul's avatar
Paul committed
648
649
650
651
652
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().output_alias(input);
    }

653
654
655
656
657
658
    value compile(context& ctx, const shape& output, const std::vector<shape>& input)
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().compile(ctx, output, input);
    }

Paul's avatar
Paul committed
659
660
661
662
663
664
    void finalize(context& ctx, const shape& output, const std::vector<shape>& input)
    {
        assert((*this).private_detail_te_handle_mem_var);
        (*this).private_detail_te_get_handle().finalize(ctx, output, input);
    }

Paul's avatar
Paul committed
665
    shape compute_shape(const std::vector<shape>& input) const
Paul's avatar
Paul committed
666
    {
Paul's avatar
Paul committed
667
        assert((*this).private_detail_te_handle_mem_var);
Paul's avatar
Paul committed
668
        return (*this).private_detail_te_get_handle().compute_shape(input);
Paul's avatar
Paul committed
669
670
    }

Shucai Xiao's avatar
Shucai Xiao committed
671
672
673
674
675
676
677
    shape compute_shape(const std::vector<shape>& inputs,
                        const std::vector<module_ref>& mod_args) const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().compute_shape(inputs, mod_args);
    }

Paul's avatar
Paul committed
678
    argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const
Paul's avatar
Paul committed
679
    {
Paul's avatar
Paul committed
680
        assert((*this).private_detail_te_handle_mem_var);
Paul's avatar
Paul committed
681
        return (*this).private_detail_te_get_handle().compute(ctx, output, input);
Paul's avatar
Paul committed
682
683
    }

Paul's avatar
Paul committed
684
685
686
687
688
689
    argument compute(const shape& output, const std::vector<argument>& input) const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().compute(output, input);
    }

Shucai Xiao's avatar
Shucai Xiao committed
690
691
692
693
694
    /// Evaluates the wrapped operation with submodule arguments. The `run`
    /// callback (module_ref plus named arguments -> result arguments) is moved
    /// into the type-erased implementation unchanged.
    argument compute(const shape& output,
                     const std::vector<argument>& input,
                     const std::vector<module_ref>& module_args,
                     std::function<std::vector<argument>(
                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const
    {
        assert(this->private_detail_te_handle_mem_var);
        const auto& erased = this->private_detail_te_get_handle();
        return erased.compute(output, input, module_args, std::move(run));
    }

    /// Evaluates the wrapped operation in a context with submodule arguments.
    /// The `run` callback is moved into the type-erased implementation unchanged.
    argument compute(context& ctx,
                     const shape& output,
                     const std::vector<argument>& input,
                     const std::vector<module_ref>& module_args,
                     std::function<std::vector<argument>(
                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const
    {
        assert(this->private_detail_te_handle_mem_var);
        const auto& erased = this->private_detail_te_get_handle();
        return erased.compute(ctx, output, input, module_args, std::move(run));
    }

713
714
715
716
717
718
719
720
721
722
723
724
    /// Returns the `value` produced by the type-erased implementation's to_value().
    value to_value() const
    {
        assert(this->private_detail_te_handle_mem_var);
        const auto& erased = this->private_detail_te_get_handle();
        return erased.to_value();
    }

    /// Passes `v` to the type-erased implementation's from_value(). Goes
    /// through the mutable handle accessor, which clones the handle first when
    /// it is not uniquely owned.
    void from_value(const value& v)
    {
        assert(this->private_detail_te_handle_mem_var);
        auto& erased = this->private_detail_te_get_handle();
        erased.from_value(v);
    }

725
726
727
728
729
730
    /// Returns the attributes reported by the type-erased implementation.
    value attributes() const
    {
        assert(this->private_detail_te_handle_mem_var);
        const auto& erased = this->private_detail_te_get_handle();
        return erased.attributes();
    }

Paul's avatar
Paul committed
731
732
733
734
    /// Streams `op` to `os` using the erased implementation's operator_shift_left.
    friend std::ostream& operator<<(std::ostream& os, const operation& op)
    {
        assert(op.private_detail_te_handle_mem_var);
        const auto& erased = op.private_detail_te_get_handle();
        return erased.operator_shift_left(os);
    }

Paul's avatar
Paul committed
737
738
739
740
741
742
    /// Equality is delegated to the left operand's erased implementation,
    /// which receives the whole right-hand wrapper.
    friend bool operator==(const operation& x, const operation& y)
    {
        assert(x.private_detail_te_handle_mem_var);
        const auto& lhs = x.private_detail_te_get_handle();
        return lhs.operator==(y);
    }

Paul's avatar
Paul committed
743
744
745
746
747
748
    /// True when both wrappers point at the very same implementation handle
    /// (pointer identity, not value equality).
    friend bool is_shared(const operation& private_detail_x, const operation& private_detail_y)
    {
        const auto& lhs_handle = private_detail_x.private_detail_te_handle_mem_var;
        const auto& rhs_handle = private_detail_y.private_detail_te_handle_mem_var;
        return lhs_handle == rhs_handle;
    }

Paul's avatar
Paul committed
749
    private:
Paul's avatar
Paul committed
750
    struct private_detail_te_handle_base_type
Paul's avatar
Paul committed
751
    {
Paul's avatar
Paul committed
752
753
        virtual ~private_detail_te_handle_base_type() {}
        virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
Paul's avatar
Paul committed
754
        virtual const std::type_info& type() const                                = 0;
Paul's avatar
Paul committed
755

Paul's avatar
Paul committed
756
757
        virtual std::string name() const                                           = 0;
        virtual bool is_context_free() const                                       = 0;
Shucai Xiao's avatar
Shucai Xiao committed
758
        virtual bool need_normalization() const                                    = 0;
Paul's avatar
Paul committed
759
        virtual bool has_finalize() const                                          = 0;
760
        virtual lifetime get_lifetime() const                                      = 0;
Paul's avatar
Paul committed
761
        virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
762
763
        virtual value
        compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
Paul's avatar
Paul committed
764
765
766
        virtual void
        finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
        virtual shape compute_shape(const std::vector<shape>& input) const           = 0;
Shucai Xiao's avatar
Shucai Xiao committed
767
768
        virtual shape compute_shape(const std::vector<shape>& inputs,
                                    const std::vector<module_ref>& mod_args) const   = 0;
Paul's avatar
Paul committed
769
        virtual argument
Paul's avatar
Paul committed
770
771
        compute(context& ctx, const shape& output, const std::vector<argument>& input) const    = 0;
        virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
Shucai Xiao's avatar
Shucai Xiao committed
772
        virtual argument
Shucai Xiao's avatar
Shucai Xiao committed
773
774
775
776
777
778
779
780
781
        compute(const shape& output,
                const std::vector<argument>& input,
                const std::vector<module_ref>& module_args,
                std::function<std::vector<argument>(
                    module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
        virtual argument
        compute(context& ctx,
                const shape& output,
                const std::vector<argument>& input,
Shucai Xiao's avatar
Shucai Xiao committed
782
783
                const std::vector<module_ref>& module_args,
                std::function<std::vector<argument>(
Shucai Xiao's avatar
Shucai Xiao committed
784
785
786
787
788
789
                    module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
        virtual value to_value() const                                                         = 0;
        virtual void from_value(const value& v)                                                = 0;
        virtual value attributes() const                                                       = 0;
        virtual std::ostream& operator_shift_left(std::ostream& os) const                      = 0;
        virtual bool operator==(const operation& y) const                                      = 0;
Paul's avatar
Paul committed
790
791
    };

792
793
794
795
796
797
798
799
800
801
802
803
804
    // Tag-dispatched default for is_context_free. Callers pass char(0), so the
    // `char` overload wins whenever T has its own is_context_free() (it is
    // SFINAE'd out otherwise); the `float` overload then falls back to
    // detail::is_context_free_op.
    template <class T>
    static auto private_detail_te_default_is_context_free(char, T&& self)
        -> decltype(self.is_context_free())
    {
        return self.is_context_free();
    }

    template <class T>
    static bool private_detail_te_default_is_context_free(float, T&& self)
    {
        return detail::is_context_free_op(self);
    }

Shucai Xiao's avatar
Shucai Xiao committed
805
806
807
808
809
810
811
812
813
814
815
816
817
    // Tag-dispatched default for need_normalization: prefer T's own member
    // (the `char` overload, viable only if the call is well-formed), otherwise
    // use detail::need_normalization_op.
    template <class T>
    static auto private_detail_te_default_need_normalization(char, T&& self)
        -> decltype(self.need_normalization())
    {
        return self.need_normalization();
    }

    template <class T>
    static bool private_detail_te_default_need_normalization(float, T&& self)
    {
        return detail::need_normalization_op(self);
    }

818
819
820
821
822
823
824
825
826
827
828
829
830
    // Tag-dispatched default for has_finalize: T's own member if present,
    // otherwise detail::has_finalize_op.
    template <class T>
    static auto private_detail_te_default_has_finalize(char, T&& self)
        -> decltype(self.has_finalize())
    {
        return self.has_finalize();
    }

    template <class T>
    static bool private_detail_te_default_has_finalize(float, T&& self)
    {
        return detail::has_finalize_op(self);
    }

831
    template <class T>
832
833
    static auto private_detail_te_default_get_lifetime(char, T&& private_detail_te_self)
        -> decltype(private_detail_te_self.get_lifetime())
834
    {
835
        return private_detail_te_self.get_lifetime();
836
837
838
    }

    template <class T>
839
    static lifetime private_detail_te_default_get_lifetime(float, T&& private_detail_te_self)
840
    {
841
        return detail::get_lifetime_op(private_detail_te_self);
842
843
    }

844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
    // Tag-dispatched default for output_alias: T's own member if present,
    // otherwise detail::output_alias_op.
    template <class T>
    static auto
    private_detail_te_default_output_alias(char, T&& self, const std::vector<shape>& input)
        -> decltype(self.output_alias(input))
    {
        return self.output_alias(input);
    }

    template <class T>
    static std::ptrdiff_t
    private_detail_te_default_output_alias(float, T&& self, const std::vector<shape>& input)
    {
        return detail::output_alias_op(self, input);
    }

861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
    // Tag-dispatched default for compile: T's own member if present,
    // otherwise detail::compile_op.
    template <class T>
    static auto private_detail_te_default_compile(
        char, T&& self, context& ctx, const shape& output, const std::vector<shape>& input)
        -> decltype(self.compile(ctx, output, input))
    {
        return self.compile(ctx, output, input);
    }

    template <class T>
    static value private_detail_te_default_compile(
        float, T&& self, context& ctx, const shape& output, const std::vector<shape>& input)
    {
        return detail::compile_op(self, ctx, output, input);
    }

882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
    // Tag-dispatched default for finalize. Callers pass char(0), so the `char`
    // overload is chosen whenever T has its own finalize(ctx, output, input);
    // otherwise the `float` overload defers to detail::finalize_op.
    template <class T>
    static auto private_detail_te_default_finalize(char,
                                                   T&& private_detail_te_self,
                                                   context& ctx,
                                                   const shape& output,
                                                   const std::vector<shape>& input)
        -> decltype(private_detail_te_self.finalize(ctx, output, input))
    {
        // The declared return type is whatever T::finalize returns. Returning
        // the call (legal even for void) avoids flowing off the end of a
        // value-returning function (UB) if a concrete op's finalize is non-void.
        return private_detail_te_self.finalize(ctx, output, input);
    }

    template <class T>
    static void private_detail_te_default_finalize(float,
                                                   T&& private_detail_te_self,
                                                   context& ctx,
                                                   const shape& output,
                                                   const std::vector<shape>& input)
    {
        detail::finalize_op(private_detail_te_self, ctx, output, input);
    }

Shucai Xiao's avatar
Shucai Xiao committed
903
904
905
906
907
908
909
910
911
912
913
914
915
916
    // Tag-dispatched default for compute_shape(input): T's own member if
    // present, otherwise detail::compute_shape_op.
    template <class T>
    static auto
    private_detail_te_default_compute_shape(char, T&& self, const std::vector<shape>& input)
        -> decltype(self.compute_shape(input))
    {
        return self.compute_shape(input);
    }

    template <class T>
    static shape
    private_detail_te_default_compute_shape(float, T&& self, const std::vector<shape>& input)
    {
        return detail::compute_shape_op(self, input);
    }

Shucai Xiao's avatar
Shucai Xiao committed
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
    // Tag-dispatched default for compute_shape(inputs, mod_args): T's own
    // member if present, otherwise detail::mod_compute_shape_op.
    template <class T>
    static auto private_detail_te_default_compute_shape(char,
                                                        T&& self,
                                                        const std::vector<shape>& inputs,
                                                        const std::vector<module_ref>& mod_args)
        -> decltype(self.compute_shape(inputs, mod_args))
    {
        return self.compute_shape(inputs, mod_args);
    }

    template <class T>
    static shape private_detail_te_default_compute_shape(float,
                                                         T&& self,
                                                         const std::vector<shape>& inputs,
                                                         const std::vector<module_ref>& mod_args)
    {
        return detail::mod_compute_shape_op(self, inputs, mod_args);
    }

939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
    // Tag-dispatched defaults for compute(ctx, output, input): T's own member
    // if present, otherwise detail::compute_op.
    template <class T>
    static auto private_detail_te_default_compute(
        char, T&& self, context& ctx, const shape& output, const std::vector<argument>& input)
        -> decltype(self.compute(ctx, output, input))
    {
        return self.compute(ctx, output, input);
    }

    template <class T>
    static argument private_detail_te_default_compute(
        float, T&& self, context& ctx, const shape& output, const std::vector<argument>& input)
    {
        return detail::compute_op(self, ctx, output, input);
    }

    // Same dispatch for the context-free compute(output, input).
    template <class T>
    static auto private_detail_te_default_compute(char,
                                                  T&& self,
                                                  const shape& output,
                                                  const std::vector<argument>& input)
        -> decltype(self.compute(output, input))
    {
        return self.compute(output, input);
    }

    template <class T>
    static argument private_detail_te_default_compute(float,
                                                      T&& self,
                                                      const shape& output,
                                                      const std::vector<argument>& input)
    {
        return detail::compute_op(self, output, input);
    }

Shucai Xiao's avatar
Shucai Xiao committed
979
980
981
982
    template <class T>
    static auto private_detail_te_default_compute(
        char,
        T&& private_detail_te_self,
Shucai Xiao's avatar
Shucai Xiao committed
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
        const shape& output,
        const std::vector<argument>& input,
        const std::vector<module_ref>& module_args,
        std::function<std::vector<argument>(module_ref&,
                                            const std::unordered_map<std::string, argument>&)> run)
        -> decltype(private_detail_te_self.compute(output, input, module_args, std::move(run)))
    {
        return private_detail_te_self.compute(output, input, module_args, std::move(run));
    }

    template <class T>
    static argument private_detail_te_default_compute(
        float,
        T&& private_detail_te_self,
        const shape& output,
        const std::vector<argument>& input,
        const std::vector<module_ref>& module_args,
        std::function<std::vector<argument>(module_ref&,
                                            const std::unordered_map<std::string, argument>&)> run)
    {
        return detail::compute_op(
            private_detail_te_self, output, input, module_args, std::move(run));
    }

    template <class T>
    static auto private_detail_te_default_compute(
        char,
        T&& private_detail_te_self,
        context& ctx,
        const shape& output,
Shucai Xiao's avatar
Shucai Xiao committed
1013
1014
        const std::vector<argument>& input,
        const std::vector<module_ref>& module_args,
Shucai Xiao's avatar
Shucai Xiao committed
1015
1016
1017
        std::function<std::vector<argument>(module_ref&,
                                            const std::unordered_map<std::string, argument>&)> run)
        -> decltype(private_detail_te_self.compute(ctx, output, input, module_args, std::move(run)))
Shucai Xiao's avatar
Shucai Xiao committed
1018
    {
Shucai Xiao's avatar
Shucai Xiao committed
1019
        return private_detail_te_self.compute(ctx, output, input, module_args, std::move(run));
Shucai Xiao's avatar
Shucai Xiao committed
1020
1021
1022
1023
1024
1025
    }

    template <class T>
    static argument private_detail_te_default_compute(
        float,
        T&& private_detail_te_self,
Shucai Xiao's avatar
Shucai Xiao committed
1026
1027
        context& ctx,
        const shape& output,
Shucai Xiao's avatar
Shucai Xiao committed
1028
1029
        const std::vector<argument>& input,
        const std::vector<module_ref>& module_args,
Shucai Xiao's avatar
Shucai Xiao committed
1030
1031
        std::function<std::vector<argument>(module_ref&,
                                            const std::unordered_map<std::string, argument>&)> run)
Shucai Xiao's avatar
Shucai Xiao committed
1032
    {
Shucai Xiao's avatar
Shucai Xiao committed
1033
1034
        return detail::compute_op(
            private_detail_te_self, ctx, output, input, module_args, std::move(run));
Shucai Xiao's avatar
Shucai Xiao committed
1035
1036
    }

1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
    // Tag-dispatched default for to_value: T's own member if present,
    // otherwise detail::to_value_op.
    template <class T>
    static auto private_detail_te_default_to_value(char, T&& private_detail_te_self)
        -> decltype(private_detail_te_self.to_value())
    {
        return private_detail_te_self.to_value();
    }

    template <class T>
    static value private_detail_te_default_to_value(float, T&& private_detail_te_self)
    {
        return detail::to_value_op(private_detail_te_self);
    }

    // Tag-dispatched default for from_value: T's own member if present,
    // otherwise detail::from_value_op.
    template <class T>
    static auto
    private_detail_te_default_from_value(char, T&& private_detail_te_self, const value& v)
        -> decltype(private_detail_te_self.from_value(v))
    {
        // The declared return type is whatever T::from_value returns. Returning
        // the call (legal even for void) avoids flowing off the end of a
        // value-returning function (UB) if a concrete op's from_value is non-void.
        return private_detail_te_self.from_value(v);
    }

    template <class T>
    static void
    private_detail_te_default_from_value(float, T&& private_detail_te_self, const value& v)
    {
        detail::from_value_op(private_detail_te_self, v);
    }

1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
    // Tag-dispatched default for attributes: T's own member if present,
    // otherwise detail::attributes_op.
    template <class T>
    static auto private_detail_te_default_attributes(char, T&& self)
        -> decltype(self.attributes())
    {
        return self.attributes();
    }

    template <class T>
    static value private_detail_te_default_attributes(float, T&& self)
    {
        return detail::attributes_op(self);
    }

Paul's avatar
Paul committed
1078
1079
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type : private_detail_te_handle_base_type
Paul's avatar
Paul committed
1080
    {
Paul's avatar
Paul committed
1081
1082
1083
1084
1085
1086
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
                nullptr)
            : private_detail_te_value(value)
Paul's avatar
Paul committed
1087
1088
1089
        {
        }

Paul's avatar
Paul committed
1090
1091
1092
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
1093
            typename std::enable_if<not std::is_reference<PrivateDetailTypeErasedU>::value,
Paul's avatar
Paul committed
1094
1095
                                    int>::type* = nullptr) noexcept
            : private_detail_te_value(std::move(value))
Paul's avatar
Paul committed
1096
1097
1098
        {
        }

Paul's avatar
Paul committed
1099
        std::shared_ptr<private_detail_te_handle_base_type> clone() const override
Paul's avatar
Paul committed
1100
        {
Paul's avatar
Paul committed
1101
            return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
Paul's avatar
Paul committed
1102
1103
        }

Paul's avatar
Paul committed
1104
        const std::type_info& type() const override { return typeid(private_detail_te_value); }
Paul's avatar
Paul committed
1105

Paul's avatar
Paul committed
1106
        std::string name() const override { return private_detail_te_value.name(); }
Paul's avatar
Paul committed
1107

Paul's avatar
Paul committed
1108
1109
1110
        bool is_context_free() const override
        {

1111
            return private_detail_te_default_is_context_free(char(0), private_detail_te_value);
Paul's avatar
Paul committed
1112
1113
        }

Shucai Xiao's avatar
Shucai Xiao committed
1114
1115
1116
1117
1118
1119
        bool need_normalization() const override
        {

            return private_detail_te_default_need_normalization(char(0), private_detail_te_value);
        }

1120
1121
1122
1123
1124
        bool has_finalize() const override
        {

            return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
        }
Paul's avatar
Paul committed
1125

1126
        lifetime get_lifetime() const override
1127
1128
        {

1129
            return private_detail_te_default_get_lifetime(char(0), private_detail_te_value);
1130
1131
        }

Paul's avatar
Paul committed
1132
        std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
Paul's avatar
Paul committed
1133
1134
        {

1135
            return private_detail_te_default_output_alias(char(0), private_detail_te_value, input);
Paul's avatar
Paul committed
1136
1137
        }

1138
1139
1140
1141
1142
1143
1144
        value compile(context& ctx, const shape& output, const std::vector<shape>& input) override
        {

            return private_detail_te_default_compile(
                char(0), private_detail_te_value, ctx, output, input);
        }

Paul's avatar
Paul committed
1145
1146
1147
        void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
        {

1148
1149
            private_detail_te_default_finalize(
                char(0), private_detail_te_value, ctx, output, input);
Paul's avatar
Paul committed
1150
1151
        }

Paul's avatar
Paul committed
1152
        shape compute_shape(const std::vector<shape>& input) const override
Paul's avatar
Paul committed
1153
        {
Paul's avatar
Paul committed
1154

Shucai Xiao's avatar
Shucai Xiao committed
1155
            return private_detail_te_default_compute_shape(char(0), private_detail_te_value, input);
Paul's avatar
Paul committed
1156
1157
        }

Shucai Xiao's avatar
Shucai Xiao committed
1158
1159
1160
1161
1162
1163
1164
1165
        shape compute_shape(const std::vector<shape>& inputs,
                            const std::vector<module_ref>& mod_args) const override
        {

            return private_detail_te_default_compute_shape(
                char(0), private_detail_te_value, inputs, mod_args);
        }

Paul's avatar
Paul committed
1166
1167
1168
        argument compute(context& ctx,
                         const shape& output,
                         const std::vector<argument>& input) const override
Paul's avatar
Paul committed
1169
        {
Paul's avatar
Paul committed
1170

1171
1172
            return private_detail_te_default_compute(
                char(0), private_detail_te_value, ctx, output, input);
Paul's avatar
Paul committed
1173
1174
        }

Paul's avatar
Paul committed
1175
1176
1177
        argument compute(const shape& output, const std::vector<argument>& input) const override
        {

1178
1179
            return private_detail_te_default_compute(
                char(0), private_detail_te_value, output, input);
Paul's avatar
Paul committed
1180
1181
        }

Shucai Xiao's avatar
Shucai Xiao committed
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
        argument compute(
            const shape& output,
            const std::vector<argument>& input,
            const std::vector<module_ref>& module_args,
            std::function<std::vector<argument>(
                module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
        {

            return private_detail_te_default_compute(
                char(0), private_detail_te_value, output, input, module_args, std::move(run));
        }

        argument compute(
            context& ctx,
            const shape& output,
            const std::vector<argument>& input,
            const std::vector<module_ref>& module_args,
            std::function<std::vector<argument>(
                module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
Shucai Xiao's avatar
Shucai Xiao committed
1201
1202
1203
        {

            return private_detail_te_default_compute(
Shucai Xiao's avatar
Shucai Xiao committed
1204
                char(0), private_detail_te_value, ctx, output, input, module_args, std::move(run));
Shucai Xiao's avatar
Shucai Xiao committed
1205
1206
        }

1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
        value to_value() const override
        {

            return private_detail_te_default_to_value(char(0), private_detail_te_value);
        }

        void from_value(const value& v) override
        {

            private_detail_te_default_from_value(char(0), private_detail_te_value, v);
        }

1219
1220
1221
1222
1223
1224
        value attributes() const override
        {

            return private_detail_te_default_attributes(char(0), private_detail_te_value);
        }

Paul's avatar
Paul committed
1225
1226
        std::ostream& operator_shift_left(std::ostream& os) const override
        {
1227
            using migraphx::detail::operation_operators::operator<<;
Paul's avatar
Paul committed
1228
1229
1230
            return os << private_detail_te_value;
        }

Paul's avatar
Paul committed
1231
1232
        bool operator==(const operation& y) const override
        {
1233
            using migraphx::detail::operation_operators::operator==;
Paul's avatar
Paul committed
1234
1235
1236
            return private_detail_te_value == y;
        }

Paul's avatar
Paul committed
1237
        PrivateDetailTypeErasedT private_detail_te_value;
Paul's avatar
Paul committed
1238
1239
    };

Paul's avatar
Paul committed
1240
1241
1242
    // Specialization for std::reference_wrapper<T>: unwraps it and stores a
    // plain T& by reusing the reference-type handle.
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
        : private_detail_te_handle_type<PrivateDetailTypeErasedT&>
    {
        using base_handle = private_detail_te_handle_type<PrivateDetailTypeErasedT&>;

        private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> wrapped)
            : base_handle(wrapped.get())
        {
        }
    };

Paul's avatar
Paul committed
1250
1251
1252
1253
1254
    // True when no implementation handle is held.
    bool private_detail_te_handle_empty() const { return not private_detail_te_handle_mem_var; }

Paul's avatar
Paul committed
1255
1256
    // Read-only access to the implementation handle; never copies it.
    const private_detail_te_handle_base_type& private_detail_te_get_handle() const
    {
        const auto& handle = private_detail_te_handle_mem_var;
        assert(handle != nullptr);
        return *handle;
    }
Paul's avatar
Paul committed
1260

Paul's avatar
Paul committed
1261
    // Mutable access to the implementation handle (copy-on-write): if the
    // handle is shared with another wrapper, it is cloned first so mutations
    // through the returned reference stay local to this wrapper.
    private_detail_te_handle_base_type& private_detail_te_get_handle()
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        // shared_ptr::unique() is deprecated in C++17 and removed in C++20;
        // use_count() != 1 is its documented equivalent.
        if(private_detail_te_handle_mem_var.use_count() != 1)
            private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
        return *private_detail_te_handle_mem_var;
    }

Paul's avatar
Paul committed
1269
    // Owning pointer to the type-erased implementation. It may be shared between
    // wrappers (is_shared compares it by pointer identity); the non-const handle
    // accessor clones it before returning when it is not the sole owner.
    std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
Paul's avatar
Paul committed
1270
1271
};

Paul's avatar
Paul committed
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
template <typename ValueType>
inline const ValueType* any_cast(const operation* x)
{
    return x->any_cast<ValueType>();
}

/// Pointer-form any_cast on a mutable operation: forwards to the member
/// any_cast<ValueType>(). `x` must not be null.
template <typename ValueType>
inline ValueType* any_cast(operation* x)
{
    operation& op = *x;
    return op.any_cast<ValueType>();
}

/// Reference-form any_cast: returns the stored object as ValueType& or throws
/// std::bad_cast when the member any_cast yields nullptr.
template <typename ValueType>
inline ValueType& any_cast(operation& x)
{
    using stored_type = typename std::remove_reference<ValueType>::type;
    auto* const found = x.any_cast<stored_type>();
    if(found != nullptr)
        return *found;
    throw std::bad_cast();
}

template <typename ValueType>
inline const ValueType& any_cast(const operation& x)
{
Paul's avatar
Paul committed
1296
    const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
Paul's avatar
Paul committed
1297
1298
1299
1300
    if(y == nullptr)
        throw std::bad_cast();
    return *y;
}
1301
#endif
Paul's avatar
Paul committed
1302

1303
/// Inequality defined as the negation of operation's operator==.
inline bool operator!=(const operation& x, const operation& y)
{
    const bool same = (x == y);
    return not same;
}
Paul's avatar
Paul committed
1304

1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
// Compiles a type-erased operation for the given context, output shape and
// input shapes by forwarding to operation::compile.
inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    return op.compile(ctx, output_shape, input);
}
// Overload for a context type other than `context`: wraps `ctx` via std::ref
// into a `context` (dependent_type<context, Context> presumably names
// `context` in a way that depends on Context, delaying its instantiation —
// confirm) and forwards to the overload above.
template <class Context>
inline value
compile(operation& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    dependent_type<context, Context> ctx2 = std::ref(ctx);
    return compile(op, ctx2, output_shape, input);
}
// Generic overload; participates only when
// op.compile(ctx, ctx, output_shape, input) is well-formed (trailing decltype).
// NOTE(review): `ctx` is passed twice here, unlike the other compile calls in
// this file, which take (ctx, output, input). Confirm this four-argument form is
// intentional and not meant to be op.compile(ctx, output_shape, input).
template <class T, class Context>
inline auto compile(T& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(op.compile(ctx, ctx, output_shape, input))
{
    return op.compile(ctx, ctx, output_shape, input);
}
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
inline shape compute_shape(const operation& op, const std::vector<shape>& inputs)
{
    return op.compute_shape(inputs);
}

template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
    -> decltype(op.compute_shape(inputs))
{
    return op.compute_shape(inputs);
}

template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
    -> decltype(op.normalize_compute_shape(inputs))
{
1339
    return detail::compute_shape_op(op, inputs);
1340
1341
}

Shucai Xiao's avatar
Shucai Xiao committed
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
inline shape compute_shape(const operation& op,
                           const std::vector<shape>& inputs,
                           const std::vector<module_ref>& mod_args)
{
    return op.compute_shape(inputs, mod_args);
}

template <class T>
inline auto compute_shape(const T& op,
                          const std::vector<shape>& inputs,
                          const std::vector<module_ref>& mod_args)
    -> decltype(op.compute_shape(inputs, mod_args))
{
    return op.compute_shape(inputs, mod_args);
}

template <class T>
inline auto compute_shape(const T& op,
                          const std::vector<shape>& inputs,
                          const std::vector<module_ref>& mod_args)
    -> decltype(op.normalize_compute_shape(inputs, mod_args))
{
1364
    return detail::compute_shape_op(op, inputs, mod_args);
Shucai Xiao's avatar
Shucai Xiao committed
1365
1366
}

Paul's avatar
Paul committed
1367
1368
1369
1370
1371
inline bool is_context_free(const operation& op) { return op.is_context_free(); }

template <class T>
bool is_context_free(const T& x)
{
1372
    return detail::is_context_free_op(x);
Paul's avatar
Paul committed
1373
1374
}

Shucai Xiao's avatar
Shucai Xiao committed
1375
1376
1377
1378
1379
1380
1381
1382
/// Reports the type-erased operation's need_normalization() flag.
inline bool need_normalization(const operation& op)
{
    return op.need_normalization();
}

/// Reports need_normalization for any other op type via
/// detail::need_normalization_op.
template <class Op>
bool need_normalization(const Op& x)
{
    return detail::need_normalization_op(x);
}

Paul's avatar
Paul committed
1383
1384
1385
1386
1387
inline bool has_finalize(const operation& op) { return op.has_finalize(); }

template <class T>
bool has_finalize(const T& x)
{
1388
    return detail::has_finalize_op(x);
Paul's avatar
Paul committed
1389
1390
}

1391
1392
1393
// Declarations only; the definitions are not in this chunk. By their names
// these presumably convert an operation to/from a migraphx `value`
// (serialization hooks). Confirm against the definitions.
void migraphx_to_value(value& v, const operation& op);
void migraphx_from_value(const value& v, operation& op);

Paul's avatar
Paul committed
1394
1395
#endif

Paul's avatar
Paul committed
1396
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
1397
} // namespace migraphx
Paul's avatar
Paul committed
1398
1399

#endif