operation.hpp 46.1 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <cassert>
Paul's avatar
Paul committed
5
#include <string>
Paul's avatar
Paul committed
6
#include <functional>
Paul's avatar
Paul committed
7
8
9
#include <memory>
#include <type_traits>
#include <utility>
Shucai Xiao's avatar
Shucai Xiao committed
10
#include <unordered_map>
Paul's avatar
Paul committed
11
12
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
13
#include <migraphx/normalize_attributes.hpp>
Paul's avatar
Paul committed
14
#include <migraphx/argument.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
15
#include <migraphx/module_ref.hpp>
16
#include <migraphx/serialize.hpp>
Paul's avatar
Paul committed
17
18
19
20
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
Paul's avatar
Paul committed
21
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
22

Paul's avatar
Paul committed
23
24
struct context;

Paul's avatar
Paul committed
25
26
#ifdef DOXYGEN

Paul's avatar
Paul committed
27
/// The operation interface represents an action an instruction will perform. All
Paul's avatar
Paul committed
28
29
30
31
32
/// operation classes must be CopyConstructible.
struct operation
{
    /// A unique name identifying the operation
    std::string name() const;
Paul's avatar
Paul committed
33
34
    /// An optional method that can be used to finalize the operator before running
    void finalize(context& ctx);
Paul's avatar
Paul committed
35
36
37
    /// This is used to compute the resulting shape from an operation. If an
    /// operation cannot be run with input shapes, then it should throw an
    /// exception.
Paul's avatar
Paul committed
38
    shape compute_shape(const std::vector<shape>& input) const;
Paul's avatar
Paul committed
39
    /**
Paul's avatar
Paul committed
40
41
42
43
     * @brief This performs the operation's computation.
     *
     * This method can be optional when the operation is only used as a placeholder to be lowered
     * later on.
Paul's avatar
Paul committed
44
45
46
47
48
     *
     * @param ctx This is the context created by the `target` during compilation. Implementations
     * can use the target's `context` class rather than the `context` interface class.
     * @param output This is the output shape. It is equivalent to running `compute_shape` with each
     * `shape` of the `argument`.
Paul's avatar
Paul committed
49
     * @param input This is the `argument` result from the previous instruction's computation.
Paul's avatar
Paul committed
50
51
52
     * @return Return an `argument` of the result computation. The `shape` of `argument` should be
     * the same the `output` shape.
     */
Paul's avatar
Paul committed
53
    argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
Paul's avatar
Paul committed
54
55
    /// An optional method to return which argument the output will alias. If
    /// there is no aliased output then -1 can be returned.
Paul's avatar
Paul committed
56
    std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
Paul's avatar
Paul committed
57
58
59
60
61
    /// An optional stream operator to print the operation. When this is not
    /// implemented, it will just print the operation's name.
    friend std::ostream& operator<<(std::ostream& os, const operation& op);
};

Paul's avatar
Paul committed
62
63
/// Reports whether `x` can run `compute` without being given a context.
bool is_context_free(const operation& x);

/// Reports whether `x` has to be normalized before `compute` is run.
bool need_normalization(const operation& x);

/// Reports whether `x` provides a `finalize` method.
bool has_finalize(const operation& x);
Paul's avatar
Paul committed
68

Paul's avatar
Paul committed
69
70
#else

71
72
73
namespace detail {

namespace operation_operators {
Paul's avatar
Paul committed
74
75
76
77

/// Default printer for any type with a `name()` member.
/// Output format: `name` followed by `[field=value,...]`. The bracketed list
/// is written only when `reflect_each` visits at least one field.
template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{
    os << x.name();
    bool first_field = true;
    reflect_each(x, [&](auto&& y, auto name) {
        // '[' opens the list before the first field; ',' separates later ones.
        os << (first_field ? '[' : ',');
        os << name << "=";
        stream_write_value(os, y);
        first_field = false;
    });
    if(not first_field)
        os << "]";
    return os;
}

Paul's avatar
Paul committed
91
92
93
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
94
95
    static_assert(is_reflectable<T>{} or sizeof(T) <= 1,
                  "Missing equality operator or reflect method.");
Paul's avatar
Paul committed
96
97
98
99
100
101
    if(x.name() != y.name())
        return false;
    const auto& yy = any_cast<T>(y);
    return reflect_tie(x) == reflect_tie(yy);
}

102
} // namespace operation_operators
Paul's avatar
Paul committed
103

Shucai Xiao's avatar
Shucai Xiao committed
104
template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
105
106
auto normalize_compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
    -> decltype(x.normalize_compute_shape(inputs))
Shucai Xiao's avatar
Shucai Xiao committed
107
108
109
110
111
112
{
    dependent_type<operation, T> y = x;
    normalize_attributes(y, inputs[0].lens());
    return any_cast<T>(y).normalize_compute_shape(inputs);
}

Shucai Xiao's avatar
Shucai Xiao committed
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
template <class T>
shape normalize_compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
{
    std::string name = x.name();
    MIGRAPHX_THROW("Shape not computable: " + name);
}

template <class T>
shape normalize_compute_shape_op(const T& x, const std::vector<shape>& inputs)
{
    return normalize_compute_shape_op(rank<1>{}, x, inputs);
}

// compute_shape_op(x, inputs, mod_args) forwards to
// x.compute_shape(inputs, mod_args) when that member exists. Otherwise it
// throws "Shape not computable: <name>".

template <class T>
auto compute_shape_op(rank<1>,
                      const T& x,
                      const std::vector<shape>& inputs,
                      const std::vector<module_ref>& mod_args)
    -> decltype(x.compute_shape(inputs, mod_args))
{
    return x.compute_shape(inputs, mod_args);
}

template <class T>
shape compute_shape_op(rank<0>,
                       const T& x,
                       const std::vector<shape>&,
                       const std::vector<module_ref>&)
{
    const std::string op_name = x.name();
    MIGRAPHX_THROW("Shape not computable: " + op_name);
}

template <class T>
shape compute_shape_op(const T& x,
                       const std::vector<shape>& inputs,
                       const std::vector<module_ref>& mod_args)
{
    // rank<1>{} picks the forwarding overload if it is viable, else the fallback.
    return compute_shape_op(rank<1>{}, x, inputs, mod_args);
}

// normalize_compute_shape_op(x, inputs, mod_args) forwards to
// x.normalize_compute_shape(inputs, mod_args) when that member exists.
// Otherwise it throws "Shape not computable: <name>". Unlike the single-argument
// variant, no attribute normalization is applied here.

template <class T>
auto normalize_compute_shape_op(rank<1>,
                                const T& x,
                                const std::vector<shape>& inputs,
                                std::vector<module_ref>& mod_args)
    -> decltype(x.normalize_compute_shape(inputs, mod_args))
{
    return x.normalize_compute_shape(inputs, mod_args);
}

template <class T>
shape normalize_compute_shape_op(rank<0>,
                                 const T& x,
                                 const std::vector<shape>&,
                                 const std::vector<module_ref>&)
{
    const std::string op_name = x.name();
    MIGRAPHX_THROW("Shape not computable: " + op_name);
}

template <class T>
shape normalize_compute_shape_op(const T& x,
                                 const std::vector<shape>& inputs,
                                 std::vector<module_ref>& mod_args)
{
    // rank<1>{} picks the forwarding overload if it is viable, else the fallback.
    return normalize_compute_shape_op(rank<1>{}, x, inputs, mod_args);
}

Paul's avatar
Paul committed
180
template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
181
auto compute_op(rank<1>,
Paul's avatar
Paul committed
182
183
184
185
186
187
188
189
190
191
192
193
                const T& x,
                context& ctx,
                const shape& output_shape,
                const std::vector<argument>& input)
    -> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
{
    return x.compute(auto_any_cast(ctx), output_shape, input);
}

template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
Paul's avatar
Paul committed
194
    std::string name = x.name();
Paul's avatar
Paul committed
195
    MIGRAPHX_THROW("Not computable: " + name);
Paul's avatar
Paul committed
196
197
}

Paul's avatar
Paul committed
198
template <class T>
Paul's avatar
Paul committed
199
200
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
201
{
Shucai Xiao's avatar
Shucai Xiao committed
202
    return compute_op(rank<1>{}, x, ctx, output_shape, input);
Paul's avatar
Paul committed
203
204
205
}

template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
206
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
207
208
209
210
211
212
    -> decltype(x.compute(output_shape, input))
{
    return x.compute(output_shape, input);
}

template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
213
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
Paul's avatar
Paul committed
214
215
{
    std::string name = x.name();
Shucai Xiao's avatar
Shucai Xiao committed
216
    MIGRAPHX_THROW("Not computable: " + name);
Paul's avatar
Paul committed
217
218
219
}

template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
    return compute_op(rank<1>{}, x, output_shape, input);
}

// compute_op(x, output, inputs, module_args, f) forwards to
// x.compute(output, inputs, module_args, f) when the operation supports
// submodule arguments. Otherwise it throws "Not computable: <name>".

template <class T, class F>
auto compute_op(rank<1>,
                const T& x,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>& module_args,
                F f) -> decltype(x.compute(output, inputs, module_args, f))
{
    return x.compute(output, inputs, module_args, f);
}

template <class T, class F>
argument compute_op(rank<0>,
                    const T& x,
                    const shape&,
                    const std::vector<argument>&,
                    const std::vector<module_ref>&,
                    F)
{
    const std::string op_name = x.name();
    MIGRAPHX_THROW("Not computable: " + op_name);
}

template <class T, class F>
argument compute_op(const T& x,
                    const shape& output,
                    const std::vector<argument>& inputs,
                    const std::vector<module_ref>& module_args,
                    F f)
{
    // rank<1>{} picks the forwarding overload if it is viable, else the fallback.
    return compute_op(rank<1>{}, x, output, inputs, module_args, f);
}

Shucai Xiao's avatar
Shucai Xiao committed
258
template <class T, class F>
Shucai Xiao's avatar
Shucai Xiao committed
259
auto compute_op(rank<3>,
Shucai Xiao's avatar
Shucai Xiao committed
260
                const T& x,
Shucai Xiao's avatar
Shucai Xiao committed
261
262
                context&,
                const shape& output,
Shucai Xiao's avatar
Shucai Xiao committed
263
264
                const std::vector<argument>& inputs,
                const std::vector<module_ref>& module_args,
Shucai Xiao's avatar
Shucai Xiao committed
265
                F f) -> decltype(x.compute(output, inputs, module_args, f))
Shucai Xiao's avatar
Shucai Xiao committed
266
{
Shucai Xiao's avatar
Shucai Xiao committed
267
    return x.compute(output, inputs, module_args, f);
Shucai Xiao's avatar
Shucai Xiao committed
268
269
270
}

template <class T, class F>
Shucai Xiao's avatar
Shucai Xiao committed
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
auto compute_op(rank<2>,
                const T& x,
                context&,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>&,
                F) -> decltype(x.compute(output, inputs))
{
    return x.compute(output, inputs);
}

template <class T, class F>
auto compute_op(rank<1>,
                const T& x,
                context& ctx,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>&,
                F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
    return x.compute(auto_any_cast(ctx), output, inputs);
}

template <class T, class F>
argument compute_op(rank<0>,
                    const T& x,
                    context&,
                    const shape&,
                    const std::vector<argument>&,
                    const std::vector<module_ref>&,
                    F)
Shucai Xiao's avatar
Shucai Xiao committed
302
303
304
305
306
307
308
{
    std::string name = x.name();
    MIGRAPHX_THROW("Not computable: " + name);
}

template <class T, class F>
argument compute_op(const T& x,
Shucai Xiao's avatar
Shucai Xiao committed
309
310
                    context& ctx,
                    const shape& output,
Shucai Xiao's avatar
Shucai Xiao committed
311
312
313
314
                    const std::vector<argument>& inputs,
                    const std::vector<module_ref>& module_args,
                    F f)
{
Shucai Xiao's avatar
Shucai Xiao committed
315
    return compute_op(rank<3>{}, x, ctx, output, inputs, module_args, f);
Shucai Xiao's avatar
Shucai Xiao committed
316
317
}

Paul's avatar
Paul committed
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
// Compile-time detector: is_context_free_op(x) returns std::true_type when
// x.compute(output_shape, input) (no context) is a valid expression, and
// std::false_type otherwise. The rank overloads are declared only; they are
// used solely inside decltype.

template <class T>
auto is_context_free_op(rank<1>,
                        const T& x,
                        const shape& output_shape,
                        const std::vector<argument>& input)
    -> decltype(x.compute(output_shape, input), std::true_type{});

template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
    -> std::false_type;

template <class T>
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
    rank<1>{}, x, std::declval<const shape&>(), std::declval<const std::vector<argument>&>()))
{
    // The result carries no state; a value-initialized tag is enough.
    return {};
}

Shucai Xiao's avatar
Shucai Xiao committed
336
337
338
339
340
341
342
343
344
345
346
347
348
349
// Compile-time detector: need_normalization_op(x) returns std::true_type when
// x.normalize_compute_shape(inputs) is a valid expression, and std::false_type
// otherwise. The rank overloads are declared only; they are used solely inside
// decltype.

template <class T>
auto need_normalization_op(rank<1>, const T& x, const std::vector<shape>& inputs)
    -> decltype(x.normalize_compute_shape(inputs), std::true_type{});

template <class T>
auto need_normalization_op(rank<0>, const T&, const std::vector<shape>&) -> std::false_type;

template <class T>
auto need_normalization_op(const T& x)
    -> decltype(need_normalization_op(rank<1>{}, x, std::declval<const std::vector<shape>&>()))
{
    // The result carries no state; a value-initialized tag is enough.
    return {};
}

Paul's avatar
Paul committed
350
template <class T>
351
std::ptrdiff_t output_alias_op(const T&, const std::vector<shape>&)
Paul's avatar
Paul committed
352
353
354
355
{
    return -1;
}

Paul's avatar
Paul committed
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
// finalize_op(x, ctx, output_shape, input) calls
// x.finalize(auto_any_cast(ctx), output_shape, input) when that expression is
// valid. Otherwise it does nothing. `x` is taken by non-const reference
// because finalize may mutate the operation.

template <class T>
auto finalize_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
{
    x.finalize(auto_any_cast(ctx), output_shape, input);
}

template <class T>
void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
    // No finalize member: nothing to do.
}

template <class T>
void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    // rank<1>{} picks the forwarding overload if it is viable, else the no-op.
    finalize_op(rank<1>{}, x, ctx, output_shape, input);
}

// Compile-time detector: has_finalize_op(x) returns std::true_type when a
// mutable T supports x.finalize(auto_any_cast(ctx), output_shape, input), and
// std::false_type otherwise. The rank overloads are declared only; they are
// used solely inside decltype.

template <class T>
auto has_finalize_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});

template <class T>
auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
    -> std::false_type;

template <class T>
auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
                                                           std::declval<T&>(),
                                                           std::declval<context&>(),
                                                           std::declval<const shape&>(),
                                                           std::declval<const std::vector<shape>&>()))
{
    // The result carries no state; a value-initialized tag is enough.
    return {};
}

394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
/// Preferred overload: calls x.compile(auto_any_cast(ctx), output_shape, input)
/// when that expression is valid for `x`.
template <class T>
auto compile_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(x.compile(auto_any_cast(ctx), output_shape, input))
{
    return x.compile(auto_any_cast(ctx), output_shape, input);
}

/// Fallback when `x` has no usable compile member: returns an empty object value.
template <class T>
value compile_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
    return value::object{};
}

/// Entry point.
/// `x` is taken as `T&`, not `const T&`, so a non-const `compile` member (the
/// interface declares compile non-const, like finalize) can be detected and
/// called. With `const T&`, detection always ran against a const object, so a
/// non-const compile was silently skipped and an empty object was returned.
/// Const callers still work: `T` is then deduced as `const U`.
template <class T>
value compile_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    return compile_op(rank<1>{}, x, ctx, output_shape, input);
}

417
418
419
420
421
422
/// Default attributes for operations that declare none: an empty object value.
template <class T>
value attributes_op(const T&)
{
    value no_attributes = value::object{};
    return no_attributes;
}

423
424
425
426
427
428
429
430
431
/// Default serialization: delegates to the generic migraphx::to_value.
template <class T>
value to_value_op(const T& x)
{
    value serialized = migraphx::to_value(x);
    return serialized;
}

template <class T>
void from_value_op(T& x, const value& v)
{
432
433
    if(not(v.is_object() or (v.empty() and v.is_array())))
        MIGRAPHX_THROW("Value is not an object");
434
435
436
    return migraphx::from_value(v, x);
}

437
438
439
440
441
442
/// Fallback for is_borrowed: always reports false for any type.
template <class T>
bool is_borrowed_op(const T&)
{
    constexpr bool borrowed = false;
    return borrowed;
}

443
444
} // namespace detail

Paul's avatar
Paul committed
445
/*
 * Type-erased interface for:
 *
 * struct operation
 * {
 *      std::string name() const;
 *      bool is_context_free() const;
 *      bool need_normalization() const;
 *      bool has_finalize() const;
 *      bool is_borrowed() const;
 *      std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
 *      value compile(context& ctx, const shape& output, const std::vector<shape>& input);
 *      void finalize(context& ctx, const shape& output, const std::vector<shape>& input);
 *      shape compute_shape(const std::vector<shape>& input) const;
 *      shape compute_shape(const std::vector<shape>& inputs,
 *                          const std::vector<module_ref>& mod_args) const;
 *      argument compute(context& ctx,
 *                       const shape& output,
 *                       const std::vector<argument>& input) const;
 *      argument compute(const shape& output, const std::vector<argument>& input) const;
 *      argument compute(const shape& output,
 *                       const std::vector<argument>& input,
 *                       const std::vector<module_ref>& module_args,
 *                       std::function<std::vector<argument>(
 *                           module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
 *      argument compute(context& ctx,
 *                       const shape& output,
 *                       const std::vector<argument>& input,
 *                       const std::vector<module_ref>& module_args,
 *                       std::function<std::vector<argument>(
 *                           module_ref&, const std::unordered_map<std::string, argument>&)> run) const;
 *      value to_value() const;
 *      void from_value(const value& v);
 *      value attributes() const;
 *      friend std::ostream& operator<<(std::ostream& os, const operation& op);
 *      friend bool operator==(const operation& x, const operation& y);
 * };
 *
 */
Paul's avatar
Paul committed
474

Paul's avatar
Paul committed
475
struct operation
Paul's avatar
Paul committed
476
{
Paul's avatar
Paul committed
477
    // Constructors
Paul's avatar
Paul committed
478
    operation() = default;
Paul's avatar
Paul committed
479

Paul's avatar
Paul committed
480
    template <typename PrivateDetailTypeErasedT>
Paul's avatar
Paul committed
481
    operation(PrivateDetailTypeErasedT value)
Paul's avatar
Paul committed
482
483
484
485
        : private_detail_te_handle_mem_var(
              std::make_shared<private_detail_te_handle_type<
                  typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
                  std::forward<PrivateDetailTypeErasedT>(value)))
Paul's avatar
Paul committed
486
487
488
489
    {
    }

    // Assignment
Paul's avatar
Paul committed
490
    template <typename PrivateDetailTypeErasedT>
Paul's avatar
Paul committed
491
    operation& operator=(PrivateDetailTypeErasedT value)
Paul's avatar
Paul committed
492
    {
Paul Fultz II's avatar
Paul Fultz II committed
493
494
495
496
497
498
499
500
501
502
503
        using std::swap;
        auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
        if(derived and private_detail_te_handle_mem_var.unique())
        {
            *derived = std::forward<PrivateDetailTypeErasedT>(value);
        }
        else
        {
            operation rhs(value);
            swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
        }
Paul's avatar
Paul committed
504
505
506
        return *this;
    }

Paul's avatar
Paul committed
507
508
509
510
    // Cast
    template <typename PrivateDetailTypeErasedT>
    PrivateDetailTypeErasedT* any_cast()
    {
Paul Fultz II's avatar
Paul Fultz II committed
511
        return this->type_id() == typeid(PrivateDetailTypeErasedT)
Paul's avatar
Paul committed
512
513
514
515
516
517
518
519
520
521
                   ? std::addressof(static_cast<private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    template <typename PrivateDetailTypeErasedT>
    const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
    {
Paul Fultz II's avatar
Paul Fultz II committed
522
        return this->type_id() == typeid(PrivateDetailTypeErasedT)
Paul's avatar
Paul committed
523
524
525
526
527
528
529
                   ? std::addressof(static_cast<const private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

Paul's avatar
Paul committed
530
531
532
533
534
535
536
537
    // Dynamic type of the stored value; typeid(std::nullptr_t) when no handle is held.
    const std::type_info& type_id() const
    {
        return private_detail_te_handle_empty() ? typeid(std::nullptr_t)
                                                : private_detail_te_get_handle().type();
    }

Paul's avatar
Paul committed
538
539
    std::string name() const
    {
Paul's avatar
Paul committed
540
541
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().name();
Paul's avatar
Paul committed
542
543
    }

Paul's avatar
Paul committed
544
545
546
547
548
549
    bool is_context_free() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().is_context_free();
    }

Shucai Xiao's avatar
Shucai Xiao committed
550
551
552
553
554
555
    bool need_normalization() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().need_normalization();
    }

Paul's avatar
Paul committed
556
557
558
559
560
561
    bool has_finalize() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().has_finalize();
    }

562
563
564
565
566
567
    bool is_borrowed() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().is_borrowed();
    }

Paul's avatar
Paul committed
568
    std::ptrdiff_t output_alias(const std::vector<shape>& input) const
Paul's avatar
Paul committed
569
570
571
572
573
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().output_alias(input);
    }

574
575
576
577
578
579
    // Non-const: compilation may modify the stored operation. Requires a non-empty handle.
    value compile(context& ctx, const shape& output, const std::vector<shape>& input)
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        auto& handle = private_detail_te_get_handle();
        return handle.compile(ctx, output, input);
    }

Paul's avatar
Paul committed
580
581
582
583
584
585
    // Non-const: finalization may modify the stored operation. Requires a non-empty handle.
    void finalize(context& ctx, const shape& output, const std::vector<shape>& input)
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        auto& handle = private_detail_te_get_handle();
        handle.finalize(ctx, output, input);
    }

Paul's avatar
Paul committed
586
    // Output shape for the given input shapes. Requires a non-empty handle.
    shape compute_shape(const std::vector<shape>& input) const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        const auto& handle = private_detail_te_get_handle();
        return handle.compute_shape(input);
    }

Shucai Xiao's avatar
Shucai Xiao committed
592
593
594
595
596
597
598
    // Output shape for the given input shapes and submodule arguments. Requires a non-empty handle.
    shape compute_shape(const std::vector<shape>& inputs,
                        const std::vector<module_ref>& mod_args) const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        const auto& handle = private_detail_te_get_handle();
        return handle.compute_shape(inputs, mod_args);
    }

Paul's avatar
Paul committed
599
    // Runs the operation with a context. Requires a non-empty handle.
    argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        const auto& handle = private_detail_te_get_handle();
        return handle.compute(ctx, output, input);
    }

Paul's avatar
Paul committed
605
606
607
608
609
610
    // Runs the operation without a context. Requires a non-empty handle.
    argument compute(const shape& output, const std::vector<argument>& input) const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        const auto& handle = private_detail_te_get_handle();
        return handle.compute(output, input);
    }

Shucai Xiao's avatar
Shucai Xiao committed
611
612
613
614
615
    // Runs the operation with submodule arguments. `run` is handed to the
    // implementation for evaluating a module_ref. Requires a non-empty handle.
    argument compute(const shape& output,
                     const std::vector<argument>& input,
                     const std::vector<module_ref>& module_args,
                     std::function<std::vector<argument>(
                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        const auto& handle = private_detail_te_get_handle();
        return handle.compute(output, input, module_args, std::move(run));
    }

    // Runs the operation with a context and submodule arguments. Requires a non-empty handle.
    argument compute(context& ctx,
                     const shape& output,
                     const std::vector<argument>& input,
                     const std::vector<module_ref>& module_args,
                     std::function<std::vector<argument>(
                         module_ref&, const std::unordered_map<std::string, argument>&)> run) const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        const auto& handle = private_detail_te_get_handle();
        return handle.compute(ctx, output, input, module_args, std::move(run));
    }

634
635
636
637
638
639
640
641
642
643
644
645
    // Serializes the stored operation. Requires a non-empty handle.
    value to_value() const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        const auto& handle = private_detail_te_get_handle();
        return handle.to_value();
    }

    void from_value(const value& v)
    {
        assert((*this).private_detail_te_handle_mem_var);
        (*this).private_detail_te_get_handle().from_value(v);
    }

646
647
648
649
650
651
    // Attributes reported by the stored operation. Requires a non-empty handle.
    value attributes() const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        const auto& handle = private_detail_te_get_handle();
        return handle.attributes();
    }

Paul's avatar
Paul committed
652
653
654
655
    // Prints via the stored handle. Requires `op` to be non-empty.
    friend std::ostream& operator<<(std::ostream& os, const operation& op)
    {
        assert(op.private_detail_te_handle_mem_var != nullptr);
        const auto& handle = op.private_detail_te_get_handle();
        return handle.operator_shift_left(os);
    }

Paul's avatar
Paul committed
658
659
660
661
662
663
    // Equality is decided by the left operand's handle. Requires `x` to be non-empty.
    friend bool operator==(const operation& x, const operation& y)
    {
        assert(x.private_detail_te_handle_mem_var != nullptr);
        const auto& lhs_handle = x.private_detail_te_get_handle();
        return lhs_handle.operator==(y);
    }

Paul's avatar
Paul committed
664
665
666
667
668
669
    // True when both wrappers hold the very same handle pointer (including both empty).
    friend bool is_shared(const operation& private_detail_x, const operation& private_detail_y)
    {
        const auto& lhs = private_detail_x.private_detail_te_handle_mem_var;
        const auto& rhs = private_detail_y.private_detail_te_handle_mem_var;
        return lhs == rhs;
    }

Paul's avatar
Paul committed
670
    private:
Paul's avatar
Paul committed
671
    struct private_detail_te_handle_base_type
Paul's avatar
Paul committed
672
    {
Paul's avatar
Paul committed
673
674
        virtual ~private_detail_te_handle_base_type() {}
        virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
Paul's avatar
Paul committed
675
        virtual const std::type_info& type() const                                = 0;
Paul's avatar
Paul committed
676

Paul's avatar
Paul committed
677
678
        virtual std::string name() const                                           = 0;
        virtual bool is_context_free() const                                       = 0;
Shucai Xiao's avatar
Shucai Xiao committed
679
        virtual bool need_normalization() const                                    = 0;
Paul's avatar
Paul committed
680
        virtual bool has_finalize() const                                          = 0;
681
        virtual bool is_borrowed() const                                           = 0;
Paul's avatar
Paul committed
682
        virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
683
684
        virtual value
        compile(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
Paul's avatar
Paul committed
685
686
687
        virtual void
        finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
        virtual shape compute_shape(const std::vector<shape>& input) const           = 0;
Shucai Xiao's avatar
Shucai Xiao committed
688
689
        virtual shape compute_shape(const std::vector<shape>& inputs,
                                    const std::vector<module_ref>& mod_args) const   = 0;
Paul's avatar
Paul committed
690
        virtual argument
Paul's avatar
Paul committed
691
692
        compute(context& ctx, const shape& output, const std::vector<argument>& input) const    = 0;
        virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
Shucai Xiao's avatar
Shucai Xiao committed
693
        virtual argument
Shucai Xiao's avatar
Shucai Xiao committed
694
695
696
697
698
699
700
701
702
        compute(const shape& output,
                const std::vector<argument>& input,
                const std::vector<module_ref>& module_args,
                std::function<std::vector<argument>(
                    module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
        virtual argument
        compute(context& ctx,
                const shape& output,
                const std::vector<argument>& input,
Shucai Xiao's avatar
Shucai Xiao committed
703
704
                const std::vector<module_ref>& module_args,
                std::function<std::vector<argument>(
Shucai Xiao's avatar
Shucai Xiao committed
705
706
707
708
709
710
                    module_ref&, const std::unordered_map<std::string, argument>&)> run) const = 0;
        virtual value to_value() const                                                         = 0;
        virtual void from_value(const value& v)                                                = 0;
        virtual value attributes() const                                                       = 0;
        virtual std::ostream& operator_shift_left(std::ostream& os) const                      = 0;
        virtual bool operator==(const operation& y) const                                      = 0;
Paul's avatar
Paul committed
711
712
    };

713
714
715
716
717
718
719
720
721
722
723
724
725
    // Tag-dispatched default for is_context_free(). Callers pass a char tag: the
    // char overload is an exact match and is chosen whenever T has an
    // is_context_free() member; otherwise SFINAE removes it and the float overload
    // (reached by char->float conversion) is used instead.
    template <class T>
    static auto private_detail_te_default_is_context_free(char, T&& self)
        -> decltype(self.is_context_free())
    {
        return self.is_context_free();
    }

    // Fallback when T has no is_context_free() member: delegate to detail::is_context_free_op.
    template <class T>
    static auto private_detail_te_default_is_context_free(float, T&& self) -> bool
    {
        return detail::is_context_free_op(self);
    }

Shucai Xiao's avatar
Shucai Xiao committed
726
727
728
729
730
731
732
733
734
735
736
737
738
    // Tag-dispatched default for need_normalization(): the char overload wins when T
    // defines need_normalization(); otherwise the float fallback is selected.
    template <class T>
    static auto private_detail_te_default_need_normalization(char, T&& self)
        -> decltype(self.need_normalization())
    {
        return self.need_normalization();
    }

    // Fallback: delegate to detail::need_normalization_op.
    template <class T>
    static auto private_detail_te_default_need_normalization(float, T&& self) -> bool
    {
        return detail::need_normalization_op(self);
    }

739
740
741
742
743
744
745
746
747
748
749
750
751
    // Tag-dispatched default for has_finalize(): prefers T::has_finalize() when it
    // exists (char overload), else falls back to the float overload.
    template <class T>
    static auto private_detail_te_default_has_finalize(char, T&& self)
        -> decltype(self.has_finalize())
    {
        return self.has_finalize();
    }

    // Fallback: delegate to detail::has_finalize_op.
    template <class T>
    static auto private_detail_te_default_has_finalize(float, T&& self) -> bool
    {
        return detail::has_finalize_op(self);
    }

752
753
754
755
756
757
758
759
760
761
762
763
764
    // Tag-dispatched default for is_borrowed(): prefers T::is_borrowed() when it
    // exists (char overload), else falls back to the float overload.
    template <class T>
    static auto private_detail_te_default_is_borrowed(char, T&& self)
        -> decltype(self.is_borrowed())
    {
        return self.is_borrowed();
    }

    // Fallback: delegate to detail::is_borrowed_op.
    template <class T>
    static auto private_detail_te_default_is_borrowed(float, T&& self) -> bool
    {
        return detail::is_borrowed_op(self);
    }

765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
    // Tag-dispatched default for output_alias(input): uses T::output_alias when
    // available (char overload), otherwise the float fallback.
    template <class T>
    static auto private_detail_te_default_output_alias(char,
                                                       T&& self,
                                                       const std::vector<shape>& input)
        -> decltype(self.output_alias(input))
    {
        return self.output_alias(input);
    }

    // Fallback: delegate to detail::output_alias_op.
    template <class T>
    static auto private_detail_te_default_output_alias(float,
                                                       T&& self,
                                                       const std::vector<shape>& input)
        -> std::ptrdiff_t
    {
        return detail::output_alias_op(self, input);
    }

782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
    // Tag-dispatched default for compile(ctx, output, input): uses T::compile when
    // it accepts these arguments (char overload), otherwise the float fallback.
    template <class T>
    static auto private_detail_te_default_compile(char,
                                                  T&& self,
                                                  context& ctx,
                                                  const shape& output,
                                                  const std::vector<shape>& input)
        -> decltype(self.compile(ctx, output, input))
    {
        return self.compile(ctx, output, input);
    }

    // Fallback: delegate to detail::compile_op.
    template <class T>
    static auto private_detail_te_default_compile(float,
                                                  T&& self,
                                                  context& ctx,
                                                  const shape& output,
                                                  const std::vector<shape>& input) -> value
    {
        return detail::compile_op(self, ctx, output, input);
    }

803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
    // Tag-dispatched default for finalize(ctx, output, input): calls T::finalize
    // when it accepts these arguments (char overload), otherwise the float fallback.
    template <class T>
    static auto private_detail_te_default_finalize(char,
                                                   T&& self,
                                                   context& ctx,
                                                   const shape& output,
                                                   const std::vector<shape>& input)
        -> decltype(self.finalize(ctx, output, input))
    {
        self.finalize(ctx, output, input);
    }

    // Fallback: delegate to detail::finalize_op.
    template <class T>
    static auto private_detail_te_default_finalize(float,
                                                   T&& self,
                                                   context& ctx,
                                                   const shape& output,
                                                   const std::vector<shape>& input) -> void
    {
        detail::finalize_op(self, ctx, output, input);
    }

Shucai Xiao's avatar
Shucai Xiao committed
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
    // Tag-dispatched default for compute_shape(input): uses T::compute_shape when
    // available (char overload), otherwise the float fallback.
    template <class T>
    static auto private_detail_te_default_compute_shape(char,
                                                        T&& self,
                                                        const std::vector<shape>& input)
        -> decltype(self.compute_shape(input))
    {
        return self.compute_shape(input);
    }

    // Fallback: delegate to detail::normalize_compute_shape_op.
    template <class T>
    static auto private_detail_te_default_compute_shape(float,
                                                        T&& self,
                                                        const std::vector<shape>& input) -> shape
    {
        return detail::normalize_compute_shape_op(self, input);
    }

Shucai Xiao's avatar
Shucai Xiao committed
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
    // Tag-dispatched default for compute_shape(inputs, mod_args): uses the member
    // overload taking module arguments when T has one (char overload), otherwise
    // the float fallback.
    template <class T>
    static auto private_detail_te_default_compute_shape(char,
                                                        T&& self,
                                                        const std::vector<shape>& inputs,
                                                        const std::vector<module_ref>& mod_args)
        -> decltype(self.compute_shape(inputs, mod_args))
    {
        return self.compute_shape(inputs, mod_args);
    }

    // Fallback: delegate to detail::compute_shape_op.
    template <class T>
    static auto private_detail_te_default_compute_shape(float,
                                                        T&& self,
                                                        const std::vector<shape>& inputs,
                                                        const std::vector<module_ref>& mod_args)
        -> shape
    {
        return detail::compute_shape_op(self, inputs, mod_args);
    }

860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
    // Tag-dispatched default for compute(ctx, output, input): uses T::compute with a
    // context when available (char overload), otherwise the float fallback.
    template <class T>
    static auto private_detail_te_default_compute(char,
                                                  T&& self,
                                                  context& ctx,
                                                  const shape& output,
                                                  const std::vector<argument>& input)
        -> decltype(self.compute(ctx, output, input))
    {
        return self.compute(ctx, output, input);
    }

    // Fallback: delegate to detail::compute_op (context form).
    template <class T>
    static auto private_detail_te_default_compute(float,
                                                  T&& self,
                                                  context& ctx,
                                                  const shape& output,
                                                  const std::vector<argument>& input) -> argument
    {
        return detail::compute_op(self, ctx, output, input);
    }

    // Tag-dispatched default for compute(output, input) without a context: uses
    // T::compute(output, input) when available (char overload), else the fallback.
    template <class T>
    static auto private_detail_te_default_compute(char,
                                                  T&& self,
                                                  const shape& output,
                                                  const std::vector<argument>& input)
        -> decltype(self.compute(output, input))
    {
        return self.compute(output, input);
    }

    // Fallback: delegate to detail::compute_op (context-free form).
    template <class T>
    static auto private_detail_te_default_compute(float,
                                                  T&& self,
                                                  const shape& output,
                                                  const std::vector<argument>& input) -> argument
    {
        return detail::compute_op(self, output, input);
    }

Shucai Xiao's avatar
Shucai Xiao committed
900
901
902
903
    template <class T>
    static auto private_detail_te_default_compute(
        char,
        T&& private_detail_te_self,
Shucai Xiao's avatar
Shucai Xiao committed
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
        const shape& output,
        const std::vector<argument>& input,
        const std::vector<module_ref>& module_args,
        std::function<std::vector<argument>(module_ref&,
                                            const std::unordered_map<std::string, argument>&)> run)
        -> decltype(private_detail_te_self.compute(output, input, module_args, std::move(run)))
    {
        return private_detail_te_self.compute(output, input, module_args, std::move(run));
    }

    template <class T>
    static argument private_detail_te_default_compute(
        float,
        T&& private_detail_te_self,
        const shape& output,
        const std::vector<argument>& input,
        const std::vector<module_ref>& module_args,
        std::function<std::vector<argument>(module_ref&,
                                            const std::unordered_map<std::string, argument>&)> run)
    {
        return detail::compute_op(
            private_detail_te_self, output, input, module_args, std::move(run));
    }

    template <class T>
    static auto private_detail_te_default_compute(
        char,
        T&& private_detail_te_self,
        context& ctx,
        const shape& output,
Shucai Xiao's avatar
Shucai Xiao committed
934
935
        const std::vector<argument>& input,
        const std::vector<module_ref>& module_args,
Shucai Xiao's avatar
Shucai Xiao committed
936
937
938
        std::function<std::vector<argument>(module_ref&,
                                            const std::unordered_map<std::string, argument>&)> run)
        -> decltype(private_detail_te_self.compute(ctx, output, input, module_args, std::move(run)))
Shucai Xiao's avatar
Shucai Xiao committed
939
    {
Shucai Xiao's avatar
Shucai Xiao committed
940
        return private_detail_te_self.compute(ctx, output, input, module_args, std::move(run));
Shucai Xiao's avatar
Shucai Xiao committed
941
942
943
944
945
946
    }

    template <class T>
    static argument private_detail_te_default_compute(
        float,
        T&& private_detail_te_self,
Shucai Xiao's avatar
Shucai Xiao committed
947
948
        context& ctx,
        const shape& output,
Shucai Xiao's avatar
Shucai Xiao committed
949
950
        const std::vector<argument>& input,
        const std::vector<module_ref>& module_args,
Shucai Xiao's avatar
Shucai Xiao committed
951
952
        std::function<std::vector<argument>(module_ref&,
                                            const std::unordered_map<std::string, argument>&)> run)
Shucai Xiao's avatar
Shucai Xiao committed
953
    {
Shucai Xiao's avatar
Shucai Xiao committed
954
955
        return detail::compute_op(
            private_detail_te_self, ctx, output, input, module_args, std::move(run));
Shucai Xiao's avatar
Shucai Xiao committed
956
957
    }

958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
    // Tag-dispatched default for to_value(): uses T::to_value() when present
    // (char overload), otherwise the float fallback.
    template <class T>
    static auto private_detail_te_default_to_value(char, T&& self) -> decltype(self.to_value())
    {
        return self.to_value();
    }

    // Fallback: delegate to detail::to_value_op.
    template <class T>
    static auto private_detail_te_default_to_value(float, T&& self) -> value
    {
        return detail::to_value_op(self);
    }

    // Tag-dispatched default for from_value(v): calls T::from_value(v) when present
    // (char overload), otherwise the float fallback.
    template <class T>
    static auto private_detail_te_default_from_value(char, T&& self, const value& v)
        -> decltype(self.from_value(v))
    {
        self.from_value(v);
    }

    // Fallback: delegate to detail::from_value_op.
    template <class T>
    static auto private_detail_te_default_from_value(float, T&& self, const value& v) -> void
    {
        detail::from_value_op(self, v);
    }

986
987
988
989
990
991
992
993
994
995
996
997
998
    // Tag-dispatched default for attributes(): uses T::attributes() when present
    // (char overload), otherwise the float fallback.
    template <class T>
    static auto private_detail_te_default_attributes(char, T&& self)
        -> decltype(self.attributes())
    {
        return self.attributes();
    }

    // Fallback: delegate to detail::attributes_op.
    template <class T>
    static auto private_detail_te_default_attributes(float, T&& self) -> value
    {
        return detail::attributes_op(self);
    }

Paul's avatar
Paul committed
999
1000
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type : private_detail_te_handle_base_type
Paul's avatar
Paul committed
1001
    {
Paul's avatar
Paul committed
1002
1003
1004
1005
1006
1007
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
                nullptr)
            : private_detail_te_value(value)
Paul's avatar
Paul committed
1008
1009
1010
        {
        }

Paul's avatar
Paul committed
1011
1012
1013
1014
1015
1016
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
                                    int>::type* = nullptr) noexcept
            : private_detail_te_value(std::move(value))
Paul's avatar
Paul committed
1017
1018
1019
        {
        }

Paul's avatar
Paul committed
1020
        std::shared_ptr<private_detail_te_handle_base_type> clone() const override
Paul's avatar
Paul committed
1021
        {
Paul's avatar
Paul committed
1022
            return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
Paul's avatar
Paul committed
1023
1024
        }

Paul's avatar
Paul committed
1025
        const std::type_info& type() const override { return typeid(private_detail_te_value); }
Paul's avatar
Paul committed
1026

Paul's avatar
Paul committed
1027
        std::string name() const override { return private_detail_te_value.name(); }
Paul's avatar
Paul committed
1028

Paul's avatar
Paul committed
1029
1030
1031
        bool is_context_free() const override
        {

1032
            return private_detail_te_default_is_context_free(char(0), private_detail_te_value);
Paul's avatar
Paul committed
1033
1034
        }

Shucai Xiao's avatar
Shucai Xiao committed
1035
1036
1037
1038
1039
1040
        bool need_normalization() const override
        {

            return private_detail_te_default_need_normalization(char(0), private_detail_te_value);
        }

1041
1042
1043
1044
1045
        bool has_finalize() const override
        {

            return private_detail_te_default_has_finalize(char(0), private_detail_te_value);
        }
Paul's avatar
Paul committed
1046

1047
1048
1049
1050
1051
1052
        bool is_borrowed() const override
        {

            return private_detail_te_default_is_borrowed(char(0), private_detail_te_value);
        }

Paul's avatar
Paul committed
1053
        std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
Paul's avatar
Paul committed
1054
1055
        {

1056
            return private_detail_te_default_output_alias(char(0), private_detail_te_value, input);
Paul's avatar
Paul committed
1057
1058
        }

1059
1060
1061
1062
1063
1064
1065
        value compile(context& ctx, const shape& output, const std::vector<shape>& input) override
        {

            return private_detail_te_default_compile(
                char(0), private_detail_te_value, ctx, output, input);
        }

Paul's avatar
Paul committed
1066
1067
1068
        void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
        {

1069
1070
            private_detail_te_default_finalize(
                char(0), private_detail_te_value, ctx, output, input);
Paul's avatar
Paul committed
1071
1072
        }

Paul's avatar
Paul committed
1073
        shape compute_shape(const std::vector<shape>& input) const override
Paul's avatar
Paul committed
1074
        {
Paul's avatar
Paul committed
1075

Shucai Xiao's avatar
Shucai Xiao committed
1076
            return private_detail_te_default_compute_shape(char(0), private_detail_te_value, input);
Paul's avatar
Paul committed
1077
1078
        }

Shucai Xiao's avatar
Shucai Xiao committed
1079
1080
1081
1082
1083
1084
1085
1086
        shape compute_shape(const std::vector<shape>& inputs,
                            const std::vector<module_ref>& mod_args) const override
        {

            return private_detail_te_default_compute_shape(
                char(0), private_detail_te_value, inputs, mod_args);
        }

Paul's avatar
Paul committed
1087
1088
1089
        argument compute(context& ctx,
                         const shape& output,
                         const std::vector<argument>& input) const override
Paul's avatar
Paul committed
1090
        {
Paul's avatar
Paul committed
1091

1092
1093
            return private_detail_te_default_compute(
                char(0), private_detail_te_value, ctx, output, input);
Paul's avatar
Paul committed
1094
1095
        }

Paul's avatar
Paul committed
1096
1097
1098
        argument compute(const shape& output, const std::vector<argument>& input) const override
        {

1099
1100
            return private_detail_te_default_compute(
                char(0), private_detail_te_value, output, input);
Paul's avatar
Paul committed
1101
1102
        }

Shucai Xiao's avatar
Shucai Xiao committed
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
        argument compute(
            const shape& output,
            const std::vector<argument>& input,
            const std::vector<module_ref>& module_args,
            std::function<std::vector<argument>(
                module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
        {

            return private_detail_te_default_compute(
                char(0), private_detail_te_value, output, input, module_args, std::move(run));
        }

        argument compute(
            context& ctx,
            const shape& output,
            const std::vector<argument>& input,
            const std::vector<module_ref>& module_args,
            std::function<std::vector<argument>(
                module_ref&, const std::unordered_map<std::string, argument>&)> run) const override
Shucai Xiao's avatar
Shucai Xiao committed
1122
1123
1124
        {

            return private_detail_te_default_compute(
Shucai Xiao's avatar
Shucai Xiao committed
1125
                char(0), private_detail_te_value, ctx, output, input, module_args, std::move(run));
Shucai Xiao's avatar
Shucai Xiao committed
1126
1127
        }

1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
        value to_value() const override
        {

            return private_detail_te_default_to_value(char(0), private_detail_te_value);
        }

        void from_value(const value& v) override
        {

            private_detail_te_default_from_value(char(0), private_detail_te_value, v);
        }

1140
1141
1142
1143
1144
1145
        value attributes() const override
        {

            return private_detail_te_default_attributes(char(0), private_detail_te_value);
        }

Paul's avatar
Paul committed
1146
1147
        std::ostream& operator_shift_left(std::ostream& os) const override
        {
1148
            using migraphx::detail::operation_operators::operator<<;
Paul's avatar
Paul committed
1149
1150
1151
            return os << private_detail_te_value;
        }

Paul's avatar
Paul committed
1152
1153
        bool operator==(const operation& y) const override
        {
1154
            using migraphx::detail::operation_operators::operator==;
Paul's avatar
Paul committed
1155
1156
1157
            return private_detail_te_value == y;
        }

Paul's avatar
Paul committed
1158
        PrivateDetailTypeErasedT private_detail_te_value;
Paul's avatar
Paul committed
1159
1160
    };

Paul's avatar
Paul committed
1161
1162
1163
    // Specialization for std::reference_wrapper<T>: unwraps the wrapper and reuses
    // the reference-holding handle so the operation refers to the original object.
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
        : private_detail_te_handle_type<PrivateDetailTypeErasedT&>
    {
        private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> wrapped)
            : private_detail_te_handle_type<PrivateDetailTypeErasedT&>(wrapped.get())
        {
        }
    };

Paul's avatar
Paul committed
1171
1172
1173
1174
1175
    // True when no handle is stored.
    bool private_detail_te_handle_empty() const { return !private_detail_te_handle_mem_var; }

Paul's avatar
Paul committed
1176
1177
    // Read-only access to the stored handle; a handle must be present (asserted).
    const private_detail_te_handle_base_type& private_detail_te_get_handle() const
    {
        const auto& handle = private_detail_te_handle_mem_var;
        assert(handle != nullptr);
        return *handle;
    }
Paul's avatar
Paul committed
1181

Paul's avatar
Paul committed
1182
    // Mutable access to the stored handle, with copy-on-write: if the handle is
    // shared with another operation, it is cloned first, so changes made through
    // the returned reference affect only this operation. A handle must be present
    // (asserted).
    private_detail_te_handle_base_type& private_detail_te_get_handle()
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        // shared_ptr::unique() is deprecated in C++17 and removed in C++20;
        // use_count() == 1 is its documented equivalent.
        if(private_detail_te_handle_mem_var.use_count() != 1)
            private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
        return *private_detail_te_handle_mem_var;
    }

Paul's avatar
Paul committed
1190
    std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
Paul's avatar
Paul committed
1191
1192
};

Paul's avatar
Paul committed
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
template <typename ValueType>
inline const ValueType* any_cast(const operation* x)
{
    return x->any_cast<ValueType>();
}

template <typename ValueType>
inline ValueType* any_cast(operation* x)
{
    return x->any_cast<ValueType>();
}

template <typename ValueType>
inline ValueType& any_cast(operation& x)
{
Paul's avatar
Paul committed
1208
    auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
Paul's avatar
Paul committed
1209
1210
1211
1212
1213
1214
1215
1216
    if(y == nullptr)
        throw std::bad_cast();
    return *y;
}

template <typename ValueType>
inline const ValueType& any_cast(const operation& x)
{
Paul's avatar
Paul committed
1217
    const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
Paul's avatar
Paul committed
1218
1219
1220
1221
1222
    if(y == nullptr)
        throw std::bad_cast();
    return *y;
}

Paul's avatar
Paul committed
1223
1224
/// Inequality of two operations, defined as the negation of operator==.
inline bool operator!=(const operation& x, const operation& y)
{
    const bool equal = (x == y);
    return !equal;
}

1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
/// Compile a type-erased operation with a type-erased context by forwarding to
/// operation::compile.
inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    return op.compile(ctx, output_shape, input);
}

/// Compile a type-erased operation with a concrete Context: the context is wrapped
/// via std::ref into a dependent_type<context, Context> and passed to the
/// context& overload.
template <class Context>
inline value
compile(operation& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    dependent_type<context, Context> erased_ctx = std::ref(ctx);
    return compile(op, erased_ctx, output_shape, input);
}
/// Compile a concrete (non-type-erased) operation T that provides its own
/// compile(ctx, output_shape, input) member; removed by SFINAE otherwise.
/// The context is passed exactly once, matching the
/// compile(context&, const shape&, const std::vector<shape>&) signature of the
/// operation interface.
template <class T, class Context>
inline auto compile(T& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(op.compile(ctx, output_shape, input))
{
    return op.compile(ctx, output_shape, input);
}
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
/// Output shape of a type-erased operation for the given input shapes.
inline shape compute_shape(const operation& op, const std::vector<shape>& inputs)
{
    shape result = op.compute_shape(inputs);
    return result;
}

/// Output shape of a concrete operation that defines compute_shape(inputs).
template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
    -> decltype(op.compute_shape(inputs))
{
    return op.compute_shape(inputs);
}

/// Output shape of a concrete operation that defines normalize_compute_shape(inputs);
/// routed through detail::normalize_compute_shape_op.
template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
    -> decltype(op.normalize_compute_shape(inputs))
{
    return detail::normalize_compute_shape_op(op, inputs);
}

Shucai Xiao's avatar
Shucai Xiao committed
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
/// Output shape of a type-erased operation given input shapes and module arguments.
inline shape compute_shape(const operation& op,
                           const std::vector<shape>& inputs,
                           const std::vector<module_ref>& mod_args)
{
    shape result = op.compute_shape(inputs, mod_args);
    return result;
}

/// Output shape of a concrete operation that defines compute_shape(inputs, mod_args).
template <class T>
inline auto compute_shape(const T& op,
                          const std::vector<shape>& inputs,
                          const std::vector<module_ref>& mod_args)
    -> decltype(op.compute_shape(inputs, mod_args))
{
    return op.compute_shape(inputs, mod_args);
}

/// Output shape of a concrete operation that defines
/// normalize_compute_shape(inputs, mod_args); routed through
/// detail::normalize_compute_shape_op.
template <class T>
inline auto compute_shape(const T& op,
                          const std::vector<shape>& inputs,
                          const std::vector<module_ref>& mod_args)
    -> decltype(op.normalize_compute_shape(inputs, mod_args))
{
    return detail::normalize_compute_shape_op(op, inputs, mod_args);
}

Paul's avatar
Paul committed
1287
1288
1289
1290
1291
inline bool is_context_free(const operation& op) { return op.is_context_free(); }

template <class T>
bool is_context_free(const T& x)
{
1292
    return detail::is_context_free_op(x);
Paul's avatar
Paul committed
1293
1294
}

Shucai Xiao's avatar
Shucai Xiao committed
1295
1296
1297
1298
1299
1300
1301
1302
/// Normalization query for a type-erased operation (forwards to the member).
inline bool need_normalization(const operation& op)
{
    return op.need_normalization();
}

/// Normalization query for any other type, via detail::need_normalization_op.
template <class T>
bool need_normalization(const T& x)
{
    return detail::need_normalization_op(x);
}

Paul's avatar
Paul committed
1303
1304
1305
1306
1307
inline bool has_finalize(const operation& op) { return op.has_finalize(); }

template <class T>
bool has_finalize(const T& x)
{
1308
    return detail::has_finalize_op(x);
Paul's avatar
Paul committed
1309
1310
}

1311
1312
1313
void migraphx_to_value(value& v, const operation& op);
void migraphx_from_value(const value& v, operation& op);

Paul's avatar
Paul committed
1314
1315
#endif

Paul's avatar
Paul committed
1316
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
1317
} // namespace migraphx
Paul's avatar
Paul committed
1318
1319

#endif