operation.hpp 19.1 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <cassert>
Paul's avatar
Paul committed
5
#include <string>
Paul's avatar
Paul committed
6
#include <functional>
Paul's avatar
Paul committed
7
8
9
#include <memory>
#include <type_traits>
#include <utility>
Paul's avatar
Paul committed
10
11
12
13
14
15
16
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
Paul's avatar
Paul committed
17
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
18

Paul's avatar
Paul committed
19
20
struct context;

Paul's avatar
Paul committed
21
22
#ifdef DOXYGEN

Paul's avatar
Paul committed
23
/// The operation interface represents an action an instruction will perform. All
Paul's avatar
Paul committed
24
25
26
27
28
/// operation classes must be CopyConstructible.
struct operation
{
    /// A unique name identifying the operation
    std::string name() const;
Paul's avatar
Paul committed
29
30
    /// An optional method that can be used to finalize the operator before running
    void finalize(context& ctx);
Paul's avatar
Paul committed
31
32
33
    /// This is used to compute the resulting shape from an operation. If an
    /// operation cannot be run with input shapes, then it should throw an
    /// exception.
Paul's avatar
Paul committed
34
    shape compute_shape(const std::vector<shape>& input) const;
Paul's avatar
Paul committed
35
    /**
Paul's avatar
Paul committed
36
37
38
39
     * @brief This performs the operation's computation.
     *
     * This method can be optional when the operation is only used as a placeholder to be lowered
     * later on.
Paul's avatar
Paul committed
40
41
42
43
44
     *
     * @param ctx This is the context created by the `target` during compilation. Implementations
     * can use the target's `context` class rather than the `context` interface class.
     * @param output This is the output shape. It is equivalent to running `compute_shape` with each
     * `shape` of the `argument`.
Paul's avatar
Paul committed
45
     * @param input This is the `argument` result from the previous instruction's computation.
Paul's avatar
Paul committed
46
47
48
     * @return Return an `argument` of the result computation. The `shape` of `argument` should be
     * the same the `output` shape.
     */
Paul's avatar
Paul committed
49
    argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
Paul's avatar
Paul committed
50
51
    /// An optional method to return which argument the output will alias. If
    /// there is no aliased output then -1 can be returned.
Paul's avatar
Paul committed
52
    std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
Paul's avatar
Paul committed
53
54
55
56
57
    /// An optional stream operator to print the operation. When this is not
    /// implemented, it will just print the operation's name.
    friend std::ostream& operator<<(std::ostream& os, const operation& op);
};

Paul's avatar
Paul committed
58
59
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
Paul's avatar
Paul committed
60
61
/// Returns true if the operation has a finalize method
bool has_finalize(const operation& x);
Paul's avatar
Paul committed
62

Paul's avatar
Paul committed
63
64
#else

Paul's avatar
Paul committed
65
66
67
68
69
namespace operation_stream {

/// Fallback stream operator for any type with a `name()` member.
///
/// Writes the name, then each reflected field as `field=value`, separated by commas and
/// wrapped in square brackets. When the type has no reflected fields, only the name is
/// written (no brackets).
template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{
    os << x.name();
    bool has_fields = false;
    reflect_each(x, [&](auto&& field, auto field_name) {
        // The first field opens the bracket; later fields are comma separated.
        os << (has_fields ? ',' : '[') << field_name << "=";
        stream_write_value(os, field);
        has_fields = true;
    });
    if(has_fields)
        os << "]";
    return os;
}

} // namespace operation_stream

Paul's avatar
Paul committed
85
86
87
88
89
namespace operation_equal {

template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
90
91
    static_assert(is_reflectable<T>{} or sizeof(T) <= 1,
                  "Missing equality operator or reflect method.");
Paul's avatar
Paul committed
92
93
94
95
96
97
98
99
    if(x.name() != y.name())
        return false;
    const auto& yy = any_cast<T>(y);
    return reflect_tie(x) == reflect_tie(yy);
}

} // namespace operation_equal

Paul's avatar
Paul committed
100
template <class T>
Paul's avatar
Paul committed
101
auto compute_op(rank<2>,
Paul's avatar
Paul committed
102
103
104
105
106
107
108
109
110
                const T& x,
                context& ctx,
                const shape& output_shape,
                const std::vector<argument>& input)
    -> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
{
    return x.compute(auto_any_cast(ctx), output_shape, input);
}

Paul's avatar
Paul committed
111
112
113
114
115
116
117
118
template <class T>
auto compute_op(
    rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
    -> decltype(x.compute(output_shape, input))
{
    return x.compute(output_shape, input);
}

Paul's avatar
Paul committed
119
120
121
template <class T>
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
{
Paul's avatar
Paul committed
122
    std::string name = x.name();
Paul's avatar
Paul committed
123
    MIGRAPHX_THROW("Not computable: " + name);
Paul's avatar
Paul committed
124
125
}

Paul's avatar
Paul committed
126
template <class T>
Paul's avatar
Paul committed
127
128
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
129
{
Paul's avatar
Paul committed
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
    return compute_op(rank<2>{}, x, ctx, output_shape, input);
}

/// Preferred: the operation provides a context-free `compute`.
template <class T>
auto compute_op(rank<2>, const T& op, const shape& out_shape, const std::vector<argument>& args)
    -> decltype(op.compute(out_shape, args))
{
    return op.compute(out_shape, args);
}

/// The operation can only compute with a context, which this overload set does not have.
template <class T>
auto compute_op(rank<1>, const T& op, const shape& out_shape, const std::vector<argument>& args)
    -> decltype(op.compute(auto_any_cast(std::declval<context&>()), out_shape, args))
{
    const std::string op_name = op.name();
    MIGRAPHX_THROW("Not computable without a context: " + op_name);
}

/// Fallback: the operation has no `compute` at all.
template <class T>
argument compute_op(rank<0>, const T& op, const shape&, const std::vector<argument>&)
{
    const std::string op_name = op.name();
    MIGRAPHX_THROW("Not computable: " + op_name);
}

/// Computes `op` without a context. Throws if `op` requires a context or cannot compute.
template <class T>
argument compute_op(const T& op, const shape& out_shape, const std::vector<argument>& args)
{
    return compute_op(rank<2>{}, op, out_shape, args);
}

/// Viable only when `T` has `compute(output, inputs)`, i.e. it can compute without a context.
/// Declared but never defined: it is only used inside decltype.
template <class T>
auto is_context_free_op(rank<1>,
                        const T& op,
                        const shape& out_shape,
                        const std::vector<argument>& args)
    -> decltype(op.compute(out_shape, args), std::true_type{});

/// Fallback when no context-free `compute` exists. Declared only, like the overload above.
template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
    -> std::false_type;

/// Yields std::true_type when `op` can compute without a context, std::false_type otherwise.
template <class T>
auto is_context_free_op(const T& op)
    -> decltype(is_context_free_op(rank<1>{},
                                   op,
                                   std::declval<const shape&>(),
                                   std::declval<std::vector<argument>>()))
{
    return {};
}

Paul's avatar
Paul committed
179
template <class T>
Paul's avatar
Paul committed
180
std::ptrdiff_t output_alias_op(rank<0>, const T&, const std::vector<shape>&)
Paul's avatar
Paul committed
181
182
183
184
185
186
187
188
189
190
191
192
{
    return -1;
}

template <class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
    -> decltype(x.output_alias(shapes))
{
    return x.output_alias(shapes);
}

template <class T>
Paul's avatar
Paul committed
193
std::ptrdiff_t output_alias_op(const T& x, const std::vector<shape>& shapes)
Paul's avatar
Paul committed
194
195
196
197
{
    return output_alias_op(rank<1>{}, x, shapes);
}

Paul's avatar
Paul committed
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
template <class T>
auto finalize_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
{
    x.finalize(auto_any_cast(ctx), output_shape, input);
}

template <class T>
void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
}

template <class T>
void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    finalize_op(rank<1>{}, x, ctx, output_shape, input);
}

template <class T>
auto has_finalize_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});

template <class T>
auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
    -> std::false_type;

template <class T>
auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
                                                           std::declval<T&>(),
                                                           std::declval<context&>(),
                                                           std::declval<const shape&>(),
                                                           std::declval<std::vector<shape>>()))
{
    return {};
}

Paul's avatar
Paul committed
236
/*
Paul's avatar
Paul committed
237
238
 * Type-erased interface for:
 *
Paul's avatar
Paul committed
239
 * struct operation
Paul's avatar
Paul committed
240
 * {
Paul's avatar
Paul committed
241
 *      std::string name() const;
Paul's avatar
Paul committed
242
 *      bool is_context_free() const;
Paul's avatar
Paul committed
243
 *      bool has_finalize() const;
Paul's avatar
Paul committed
244
 *      std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
Paul's avatar
Paul committed
245
 *      void finalize(context& ctx,const shape& output,const std::vector<shape>& input) ;
Paul's avatar
Paul committed
246
247
 *      shape compute_shape(const std::vector<shape>& input) const;
 *      argument compute(context& ctx,const shape& output,const std::vector<argument>& input) const;
Paul's avatar
Paul committed
248
 *      argument compute(const shape& output,const std::vector<argument>& input) const;
Paul's avatar
Paul committed
249
 *     friend std::ostream & operator<<(std::ostream & os,const operation & op) ;
Paul's avatar
Paul committed
250
 *     friend bool operator==(const operation & x,const operation & y) ;
Paul's avatar
Paul committed
251
252
253
 * };
 *
 */
Paul's avatar
Paul committed
254

Paul's avatar
Paul committed
255
struct operation
Paul's avatar
Paul committed
256
{
Paul's avatar
Paul committed
257
    // Constructors
Paul's avatar
Paul committed
258
    operation() = default;
Paul's avatar
Paul committed
259

Paul's avatar
Paul committed
260
    template <typename PrivateDetailTypeErasedT>
Paul's avatar
Paul committed
261
    operation(PrivateDetailTypeErasedT value)
Paul's avatar
Paul committed
262
263
264
265
        : private_detail_te_handle_mem_var(
              std::make_shared<private_detail_te_handle_type<
                  typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
                  std::forward<PrivateDetailTypeErasedT>(value)))
Paul's avatar
Paul committed
266
267
268
269
    {
    }

    // Assignment
Paul's avatar
Paul committed
270
    template <typename PrivateDetailTypeErasedT>
Paul's avatar
Paul committed
271
    operation& operator=(PrivateDetailTypeErasedT value)
Paul's avatar
Paul committed
272
    {
Paul's avatar
Paul committed
273
274
275
276
277
        if(private_detail_te_handle_mem_var.unique())
            *private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
        else if(!private_detail_te_handle_mem_var)
            private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
                std::forward<PrivateDetailTypeErasedT>(value));
Paul's avatar
Paul committed
278
279
280
        return *this;
    }

Paul's avatar
Paul committed
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
    // Cast
    template <typename PrivateDetailTypeErasedT>
    PrivateDetailTypeErasedT* any_cast()
    {
        return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    template <typename PrivateDetailTypeErasedT>
    const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
    {
        return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<const private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

Paul's avatar
Paul committed
304
305
306
307
308
309
310
311
    const std::type_info& type_id() const
    {
        if(private_detail_te_handle_empty())
            return typeid(std::nullptr_t);
        else
            return private_detail_te_get_handle().type();
    }

Paul's avatar
Paul committed
312
313
    std::string name() const
    {
Paul's avatar
Paul committed
314
315
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().name();
Paul's avatar
Paul committed
316
317
    }

Paul's avatar
Paul committed
318
319
320
321
322
323
    bool is_context_free() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().is_context_free();
    }

Paul's avatar
Paul committed
324
325
326
327
328
329
    bool has_finalize() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().has_finalize();
    }

Paul's avatar
Paul committed
330
    std::ptrdiff_t output_alias(const std::vector<shape>& input) const
Paul's avatar
Paul committed
331
332
333
334
335
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().output_alias(input);
    }

Paul's avatar
Paul committed
336
337
338
339
340
341
    void finalize(context& ctx, const shape& output, const std::vector<shape>& input)
    {
        assert((*this).private_detail_te_handle_mem_var);
        (*this).private_detail_te_get_handle().finalize(ctx, output, input);
    }

Paul's avatar
Paul committed
342
    shape compute_shape(const std::vector<shape>& input) const
Paul's avatar
Paul committed
343
    {
Paul's avatar
Paul committed
344
        assert((*this).private_detail_te_handle_mem_var);
Paul's avatar
Paul committed
345
        return (*this).private_detail_te_get_handle().compute_shape(input);
Paul's avatar
Paul committed
346
347
    }

Paul's avatar
Paul committed
348
    argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const
Paul's avatar
Paul committed
349
    {
Paul's avatar
Paul committed
350
        assert((*this).private_detail_te_handle_mem_var);
Paul's avatar
Paul committed
351
        return (*this).private_detail_te_get_handle().compute(ctx, output, input);
Paul's avatar
Paul committed
352
353
    }

Paul's avatar
Paul committed
354
355
356
357
358
359
    argument compute(const shape& output, const std::vector<argument>& input) const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().compute(output, input);
    }

Paul's avatar
Paul committed
360
361
362
363
    friend std::ostream& operator<<(std::ostream& os, const operation& op)
    {
        assert(op.private_detail_te_handle_mem_var);
        return op.private_detail_te_get_handle().operator_shift_left(os);
Paul's avatar
Paul committed
364
365
    }

Paul's avatar
Paul committed
366
367
368
369
370
371
    friend bool operator==(const operation& x, const operation& y)
    {
        assert(x.private_detail_te_handle_mem_var);
        return x.private_detail_te_get_handle().operator==(y);
    }

Paul's avatar
Paul committed
372
373
374
375
376
377
    friend bool is_shared(const operation& private_detail_x, const operation& private_detail_y)
    {
        return private_detail_x.private_detail_te_handle_mem_var ==
               private_detail_y.private_detail_te_handle_mem_var;
    }

Paul's avatar
Paul committed
378
    private:
Paul's avatar
Paul committed
379
    struct private_detail_te_handle_base_type
Paul's avatar
Paul committed
380
    {
Paul's avatar
Paul committed
381
382
        virtual ~private_detail_te_handle_base_type() {}
        virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
Paul's avatar
Paul committed
383
        virtual const std::type_info& type() const                                = 0;
Paul's avatar
Paul committed
384

Paul's avatar
Paul committed
385
386
387
388
        virtual std::string name() const                                           = 0;
        virtual bool is_context_free() const                                       = 0;
        virtual bool has_finalize() const                                          = 0;
        virtual std::ptrdiff_t output_alias(const std::vector<shape>& input) const = 0;
Paul's avatar
Paul committed
389
390
391
        virtual void
        finalize(context& ctx, const shape& output, const std::vector<shape>& input) = 0;
        virtual shape compute_shape(const std::vector<shape>& input) const           = 0;
Paul's avatar
Paul committed
392
        virtual argument
Paul's avatar
Paul committed
393
394
395
396
        compute(context& ctx, const shape& output, const std::vector<argument>& input) const    = 0;
        virtual argument compute(const shape& output, const std::vector<argument>& input) const = 0;
        virtual std::ostream& operator_shift_left(std::ostream& os) const                       = 0;
        virtual bool operator==(const operation& y) const                                       = 0;
Paul's avatar
Paul committed
397
398
    };

Paul's avatar
Paul committed
399
400
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type : private_detail_te_handle_base_type
Paul's avatar
Paul committed
401
    {
Paul's avatar
Paul committed
402
403
404
405
406
407
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
                nullptr)
            : private_detail_te_value(value)
Paul's avatar
Paul committed
408
409
410
        {
        }

Paul's avatar
Paul committed
411
412
413
414
415
416
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
                                    int>::type* = nullptr) noexcept
            : private_detail_te_value(std::move(value))
Paul's avatar
Paul committed
417
418
419
        {
        }

Paul's avatar
Paul committed
420
        std::shared_ptr<private_detail_te_handle_base_type> clone() const override
Paul's avatar
Paul committed
421
        {
Paul's avatar
Paul committed
422
            return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
Paul's avatar
Paul committed
423
424
        }

Paul's avatar
Paul committed
425
        const std::type_info& type() const override { return typeid(private_detail_te_value); }
Paul's avatar
Paul committed
426

Paul's avatar
Paul committed
427
        std::string name() const override { return private_detail_te_value.name(); }
Paul's avatar
Paul committed
428

Paul's avatar
Paul committed
429
430
431
432
433
434
        bool is_context_free() const override
        {

            return is_context_free_op(private_detail_te_value);
        }

Paul's avatar
Paul committed
435
436
        bool has_finalize() const override { return has_finalize_op(private_detail_te_value); }

Paul's avatar
Paul committed
437
        std::ptrdiff_t output_alias(const std::vector<shape>& input) const override
Paul's avatar
Paul committed
438
439
440
441
442
        {

            return output_alias_op(private_detail_te_value, input);
        }

Paul's avatar
Paul committed
443
444
445
446
447
448
        void finalize(context& ctx, const shape& output, const std::vector<shape>& input) override
        {

            finalize_op(private_detail_te_value, ctx, output, input);
        }

Paul's avatar
Paul committed
449
        shape compute_shape(const std::vector<shape>& input) const override
Paul's avatar
Paul committed
450
        {
Paul's avatar
Paul committed
451

Paul's avatar
Paul committed
452
            return private_detail_te_value.compute_shape(input);
Paul's avatar
Paul committed
453
454
        }

Paul's avatar
Paul committed
455
456
457
        argument compute(context& ctx,
                         const shape& output,
                         const std::vector<argument>& input) const override
Paul's avatar
Paul committed
458
        {
Paul's avatar
Paul committed
459

Paul's avatar
Paul committed
460
            return compute_op(private_detail_te_value, ctx, output, input);
Paul's avatar
Paul committed
461
462
        }

Paul's avatar
Paul committed
463
464
465
466
467
468
        argument compute(const shape& output, const std::vector<argument>& input) const override
        {

            return compute_op(private_detail_te_value, output, input);
        }

Paul's avatar
Paul committed
469
470
        std::ostream& operator_shift_left(std::ostream& os) const override
        {
Paul's avatar
Paul committed
471
            using migraphx::operation_stream::operator<<;
Paul's avatar
Paul committed
472
473
474
            return os << private_detail_te_value;
        }

Paul's avatar
Paul committed
475
476
        bool operator==(const operation& y) const override
        {
Paul's avatar
Paul committed
477
            using migraphx::operation_equal::operator==;
Paul's avatar
Paul committed
478
479
480
            return private_detail_te_value == y;
        }

Paul's avatar
Paul committed
481
        PrivateDetailTypeErasedT private_detail_te_value;
Paul's avatar
Paul committed
482
483
    };

Paul's avatar
Paul committed
484
485
486
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
        : private_detail_te_handle_type<PrivateDetailTypeErasedT&>
Paul's avatar
Paul committed
487
    {
Paul's avatar
Paul committed
488
489
        private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
            : private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
Paul's avatar
Paul committed
490
491
492
493
        {
        }
    };

Paul's avatar
Paul committed
494
495
496
497
498
    bool private_detail_te_handle_empty() const
    {
        return private_detail_te_handle_mem_var == nullptr;
    }

Paul's avatar
Paul committed
499
500
    const private_detail_te_handle_base_type& private_detail_te_get_handle() const
    {
Paul's avatar
Paul committed
501
        assert(private_detail_te_handle_mem_var != nullptr);
Paul's avatar
Paul committed
502
503
        return *private_detail_te_handle_mem_var;
    }
Paul's avatar
Paul committed
504

Paul's avatar
Paul committed
505
    private_detail_te_handle_base_type& private_detail_te_get_handle()
Paul's avatar
Paul committed
506
    {
Paul's avatar
Paul committed
507
        assert(private_detail_te_handle_mem_var != nullptr);
Paul's avatar
Paul committed
508
509
510
        if(!private_detail_te_handle_mem_var.unique())
            private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
        return *private_detail_te_handle_mem_var;
Paul's avatar
Paul committed
511
512
    }

Paul's avatar
Paul committed
513
    std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
Paul's avatar
Paul committed
514
515
};

Paul's avatar
Paul committed
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
template <typename ValueType>
inline const ValueType* any_cast(const operation* x)
{
    return x->any_cast<ValueType>();
}

template <typename ValueType>
inline ValueType* any_cast(operation* x)
{
    return x->any_cast<ValueType>();
}

template <typename ValueType>
inline ValueType& any_cast(operation& x)
{
Paul's avatar
Paul committed
531
    auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
Paul's avatar
Paul committed
532
533
534
535
536
537
538
539
    if(y == nullptr)
        throw std::bad_cast();
    return *y;
}

template <typename ValueType>
inline const ValueType& any_cast(const operation& x)
{
Paul's avatar
Paul committed
540
    const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
Paul's avatar
Paul committed
541
542
543
544
545
    if(y == nullptr)
        throw std::bad_cast();
    return *y;
}

Paul's avatar
Paul committed
546
547
/// Inequality is the negation of `operator==`.
inline bool operator!=(const operation& x, const operation& y)
{
    return not(x == y);
}

Paul's avatar
Paul committed
548
549
550
551
552
553
554
555
/// Returns true if the wrapped operation can run `compute` without a context.
inline bool is_context_free(const operation& op)
{
    return op.is_context_free();
}

/// Returns true if `x` provides the context-free `compute(output, inputs)` overload.
template <class T>
bool is_context_free(const T& x)
{
    return static_cast<bool>(is_context_free_op(x));
}

Paul's avatar
Paul committed
556
557
558
559
560
561
562
563
/// Returns true if the wrapped operation defines `finalize`.
inline bool has_finalize(const operation& op)
{
    return op.has_finalize();
}

/// Returns true if `x` has a `finalize(ctx, output, inputs)` member.
template <class T>
bool has_finalize(const T& x)
{
    return static_cast<bool>(has_finalize_op(x));
}

Paul's avatar
Paul committed
564
565
#endif

Paul's avatar
Paul committed
566
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
567
} // namespace migraphx
Paul's avatar
Paul committed
568
569

#endif