operation.hpp 19.9 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <cassert>
Paul's avatar
Paul committed
5
6
7
8
9
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
Shucai Xiao's avatar
Shucai Xiao committed
10
#include <unordered_map>
Paul's avatar
Paul committed
11
12
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
13
#include <migraphx/normalize_attributes.hpp>
Paul's avatar
Paul committed
14
#include <migraphx/argument.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
15
#include <migraphx/module_ref.hpp>
16
#include <migraphx/serialize.hpp>
Paul's avatar
Paul committed
17
#include <migraphx/auto_any_cast.hpp>
18
#include <migraphx/lifetime.hpp>
Paul's avatar
Paul committed
19
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
20

Paul's avatar
Paul committed
21
namespace migraphx {
Paul's avatar
Paul committed
22
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
23

Paul's avatar
Paul committed
24
25
struct context;

Paul's avatar
Paul committed
26
27
28
29
30
31
32
33
#ifdef DOXYGEN

/// The operation interface represents an action an instruction will perform. All
/// operation classes must be CopyConstructible.
struct operation
{
    /// A unique name identifying the operation
    std::string name() const;
Paul's avatar
Paul committed
34
35
    /// An optional method that can be used to finalize the operator before running
    void finalize(context& ctx);
Paul's avatar
Paul committed
36
37
38
    /// This is used to compute the resulting shape from an operation. If an
    /// operation cannot be run with input shapes, then it should throw an
    /// exception.
Paul's avatar
Paul committed
39
    shape compute_shape(const std::vector<shape>& input) const;
Paul's avatar
Paul committed
40
    /**
Paul's avatar
Paul committed
41
42
     * @brief This performs the operation's computation.
     *
Paul's avatar
Paul committed
43
44
     * This method can be optional when the operation is only used as a placeholder to be lowered
     * later on.
Paul's avatar
Paul committed
45
46
47
48
49
     *
     * @param ctx This is the context created by the `target` during compilation. Implementations
     * can use the target's `context` class rather than the `context` interface class.
     * @param output This is the output shape. It is equivalent to running `compute_shape` with each
     * `shape` of the `argument`.
Paul's avatar
Paul committed
50
     * @param input This is the `argument` result from the previous instruction's computation.
Paul's avatar
Paul committed
51
52
     * @return Return an `argument` of the result computation. The `shape` of `argument` should be
     * the same the `output` shape.
Paul's avatar
Paul committed
53
     */
Paul's avatar
Paul committed
54
    argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
Paul's avatar
Paul committed
55
56
    /// An optional method to return which argument the output will alias. If
    /// there is no aliased output then -1 can be returned.
Paul's avatar
Paul committed
57
    std::ptrdiff_t output_alias(const std::vector<shape>& input) const;
Paul's avatar
Paul committed
58
59
    /// An optional stream operator to print the operation. When this is not
    /// implemented, it will just print the operation's name.
Paul's avatar
Paul committed
60
    friend std::ostream& operator<<(std::ostream& os, const operation& op);
Paul's avatar
Paul committed
61
62
};

Paul's avatar
Paul committed
63
64
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
Shucai Xiao's avatar
Shucai Xiao committed
65
66
/// Returns true if operation needs normalization before running compute
bool need_normalization(const operation& x);
Paul's avatar
Paul committed
67
68
/// Returns true if the operation has a finalize method
bool has_finalize(const operation& x);
Paul's avatar
Paul committed
69

Paul's avatar
Paul committed
70
71
#else

72
73
74
namespace detail {

namespace operation_operators {
Paul's avatar
Paul committed
75

Paul's avatar
Paul committed
76
77
template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
Paul's avatar
Paul committed
78
{
Paul's avatar
Paul committed
79
80
    os << x.name();
    char delim = '[';
Paul's avatar
Paul committed
81
    reflect_each(x, [&](auto&& y, auto name) {
Paul's avatar
Paul committed
82
        os << delim;
Paul's avatar
Paul committed
83
84
        os << name << "=";
        stream_write_value(os, y);
Paul's avatar
Paul committed
85
86
        delim = ',';
    });
Paul's avatar
Paul committed
87
88
    if(delim == ',')
        os << "]";
Paul's avatar
Paul committed
89
    return os;
Paul's avatar
Paul committed
90
91
}

Paul's avatar
Paul committed
92
93
94
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
Paul's avatar
Paul committed
95
96
    static_assert(is_reflectable<T>{} or sizeof(T) <= 1,
                  "Missing equality operator or reflect method.");
Paul's avatar
Paul committed
97
98
99
100
101
102
    if(x.name() != y.name())
        return false;
    const auto& yy = any_cast<T>(y);
    return reflect_tie(x) == reflect_tie(yy);
}

103
} // namespace operation_operators
Paul's avatar
Paul committed
104

Shucai Xiao's avatar
Shucai Xiao committed
105
template <class T>
106
107
108
109
110
111
112
113
auto compute_shape_op(rank<3>, const T& x, const std::vector<shape>& inputs)
    -> decltype(x.compute_shape(inputs))
{
    return x.compute_shape(inputs);
}

template <class T>
auto compute_shape_op(rank<2>, const T& x, const std::vector<shape>& inputs)
Shucai Xiao's avatar
Shucai Xiao committed
114
    -> decltype(x.normalize_compute_shape(inputs))
Shucai Xiao's avatar
Shucai Xiao committed
115
116
117
118
119
120
{
    dependent_type<operation, T> y = x;
    normalize_attributes(y, inputs[0].lens());
    return any_cast<T>(y).normalize_compute_shape(inputs);
}

121
template <class T>
122
auto compute_shape_op(rank<1>, const T& x, const std::vector<shape>& inputs)
123
124
125
126
127
    -> decltype(x.compute_shape(inputs, {}))
{
    return x.compute_shape(inputs, {});
}

Shucai Xiao's avatar
Shucai Xiao committed
128
template <class T>
129
shape compute_shape_op(rank<0>, const T& x, const std::vector<shape>&)
Shucai Xiao's avatar
Shucai Xiao committed
130
131
132
133
134
135
{
    std::string name = x.name();
    MIGRAPHX_THROW("Shape not computable: " + name);
}

template <class T>
136
shape compute_shape_op(const T& x, const std::vector<shape>& inputs)
Shucai Xiao's avatar
Shucai Xiao committed
137
{
138
    return compute_shape_op(rank<3>{}, x, inputs);
Shucai Xiao's avatar
Shucai Xiao committed
139
140
141
}

template <class T>
142
143
144
145
auto mod_compute_shape_op(rank<1>,
                          const T& x,
                          const std::vector<shape>& inputs,
                          const std::vector<module_ref>& mod_args)
Shucai Xiao's avatar
Shucai Xiao committed
146
147
148
149
150
151
    -> decltype(x.compute_shape(inputs, mod_args))
{
    return x.compute_shape(inputs, mod_args);
}

template <class T>
152
153
154
155
shape mod_compute_shape_op(rank<0>,
                           const T& x,
                           const std::vector<shape>& inputs,
                           const std::vector<module_ref>& mod_args)
Shucai Xiao's avatar
Shucai Xiao committed
156
{
157
158
    if(mod_args.empty())
        return compute_shape_op(x, inputs);
Shucai Xiao's avatar
Shucai Xiao committed
159
160
161
162
163
    std::string name = x.name();
    MIGRAPHX_THROW("Shape not computable: " + name);
}

template <class T>
164
165
166
shape mod_compute_shape_op(const T& x,
                           const std::vector<shape>& inputs,
                           const std::vector<module_ref>& mod_args)
Shucai Xiao's avatar
Shucai Xiao committed
167
{
168
    return mod_compute_shape_op(rank<1>{}, x, inputs, mod_args);
Shucai Xiao's avatar
Shucai Xiao committed
169
170
}

Paul's avatar
Paul committed
171
template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
172
auto compute_op(rank<1>,
Paul's avatar
Paul committed
173
174
175
176
177
                const T& x,
                context& ctx,
                const shape& output_shape,
                const std::vector<argument>& input)
    -> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
Paul's avatar
Paul committed
178
179
180
181
182
{
    return x.compute(auto_any_cast(ctx), output_shape, input);
}

template <class T>
Paul's avatar
Paul committed
183
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
Paul's avatar
Paul committed
184
{
Paul's avatar
Paul committed
185
    std::string name = x.name();
Paul's avatar
Paul committed
186
    MIGRAPHX_THROW("Not computable: " + name);
Paul's avatar
Paul committed
187
188
}

Paul's avatar
Paul committed
189
template <class T>
Paul's avatar
Paul committed
190
191
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
192
{
Shucai Xiao's avatar
Shucai Xiao committed
193
    return compute_op(rank<1>{}, x, ctx, output_shape, input);
Paul's avatar
Paul committed
194
195
196
}

template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
197
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
198
199
200
201
202
203
    -> decltype(x.compute(output_shape, input))
{
    return x.compute(output_shape, input);
}

template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
204
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
Paul's avatar
Paul committed
205
206
{
    std::string name = x.name();
Shucai Xiao's avatar
Shucai Xiao committed
207
    MIGRAPHX_THROW("Not computable: " + name);
Paul's avatar
Paul committed
208
209
210
}

template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
    return compute_op(rank<1>{}, x, output_shape, input);
}

/// Calls `x.compute(output, inputs, module_args, f)` (no context) when the operator has it.
template <class T, class F>
auto compute_op(rank<1>,
                const T& x,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>& module_args,
                F f) -> decltype(x.compute(output, inputs, module_args, f))
{
    return x.compute(output, inputs, module_args, f);
}

/// Fallback when the submodule-aware, context-free `compute` is missing: always throws.
template <class T, class F>
argument compute_op(
    rank<0>, const T& x, const shape&, const std::vector<argument>&, const std::vector<module_ref>&, F)
{
    MIGRAPHX_THROW("Not computable: " + std::string(x.name()));
}

/// Entry point for context-free `compute` with submodule arguments.
template <class T, class F>
argument compute_op(const T& x,
                    const shape& output,
                    const std::vector<argument>& inputs,
                    const std::vector<module_ref>& module_args,
                    F f)
{
    return compute_op(rank<1>{}, x, output, inputs, module_args, f);
}

Shucai Xiao's avatar
Shucai Xiao committed
249
250
251
252
253
254
255
256
257
258
259
260
// Dispatch for `compute(ctx, output, inputs, module_args, f)`. Candidates are tried from the
// highest rank down:
//   rank<4>  compute(ctx, output, inputs, module_args, f)
//   rank<3>  compute(output, inputs, module_args, f)      (context ignored)
//   rank<2>  compute(output, inputs)                      (context and modules ignored)
//   rank<1>  compute(ctx, output, inputs)                 (modules ignored)
//   rank<0>  throws
// The context is converted to the operator's own context type with `auto_any_cast`.

template <class T, class F>
auto compute_op(rank<4>,
                const T& x,
                context& ctx,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>& module_args,
                F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
{
    return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
}

template <class T, class F>
auto compute_op(rank<3>,
                const T& x,
                context&,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>& module_args,
                F f) -> decltype(x.compute(output, inputs, module_args, f))
{
    return x.compute(output, inputs, module_args, f);
}

template <class T, class F>
auto compute_op(rank<2>,
                const T& x,
                context&,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>&,
                F) -> decltype(x.compute(output, inputs))
{
    return x.compute(output, inputs);
}

template <class T, class F>
auto compute_op(rank<1>,
                const T& x,
                context& ctx,
                const shape& output,
                const std::vector<argument>& inputs,
                const std::vector<module_ref>&,
                F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
{
    return x.compute(auto_any_cast(ctx), output, inputs);
}

template <class T, class F>
argument compute_op(rank<0>,
                    const T& x,
                    context&,
                    const shape&,
                    const std::vector<argument>&,
                    const std::vector<module_ref>&,
                    F)
{
    MIGRAPHX_THROW("Not computable: " + std::string(x.name()));
}

/// Entry point: starts at the highest rank so the most specific `compute` wins.
template <class T, class F>
argument compute_op(const T& x,
                    context& ctx,
                    const shape& output,
                    const std::vector<argument>& inputs,
                    const std::vector<module_ref>& module_args,
                    F f)
{
    return compute_op(rank<4>{}, x, ctx, output, inputs, module_args, f);
}

Paul's avatar
Paul committed
321
template <class T>
Paul's avatar
Paul committed
322
323
324
325
326
auto is_context_free_op(rank<1>,
                        const T& x,
                        const shape& output_shape,
                        const std::vector<argument>& input)
    -> decltype(x.compute(output_shape, input), std::true_type{});
Paul's avatar
Paul committed
327
328

template <class T>
Paul's avatar
Paul committed
329
330
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
    -> std::false_type;
Paul's avatar
Paul committed
331
332

template <class T>
Paul's avatar
Paul committed
333
334
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
    rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
Paul's avatar
Paul committed
335
336
{
    return {};
Paul's avatar
Paul committed
337
338
}

Shucai Xiao's avatar
Shucai Xiao committed
339
340
341
342
343
344
345
346
347
348
349
350
351
352
// Detection of a `normalize_compute_shape(inputs)` member; the rank-tagged overloads are only
// declared because they are used solely inside `decltype`.
template <class T>
auto need_normalization_op(rank<1>, const T& x, const std::vector<shape>& inputs)
    -> decltype(x.normalize_compute_shape(inputs), std::true_type{});

template <class T>
auto need_normalization_op(rank<0>, const T&, const std::vector<shape>&) -> std::false_type;

/// Yields `std::true_type` when T has `normalize_compute_shape`, `std::false_type` otherwise.
template <class T>
auto need_normalization_op(const T& x)
    -> decltype(need_normalization_op(rank<1>{}, x, std::declval<const std::vector<shape>&>()))
{
    return {};
}

Paul's avatar
Paul committed
353
template <class T>
354
std::ptrdiff_t output_alias_op(const T&, const std::vector<shape>&)
Paul's avatar
Paul committed
355
356
357
358
{
    return -1;
}

Paul's avatar
Paul committed
359
template <class T>
Paul's avatar
Paul committed
360
361
auto finalize_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
Paul's avatar
Paul committed
362
363
364
365
366
367
368
    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
{
    x.finalize(auto_any_cast(ctx), output_shape, input);
}

template <class T>
void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
Paul's avatar
Paul committed
369
370
{
}
Paul's avatar
Paul committed
371
372
373
374
375
376
377
378

template <class T>
void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    finalize_op(rank<1>{}, x, ctx, output_shape, input);
}

template <class T>
Paul's avatar
Paul committed
379
380
auto has_finalize_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
Paul's avatar
Paul committed
381
382
383
384
385
386
387
    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});

template <class T>
auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
    -> std::false_type;

template <class T>
Paul's avatar
Paul committed
388
389
390
391
392
auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
                                                           std::declval<T&>(),
                                                           std::declval<context&>(),
                                                           std::declval<const shape&>(),
                                                           std::declval<std::vector<shape>>()))
Paul's avatar
Paul committed
393
394
395
396
{
    return {};
}

397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
/// Calls the operator's own `compile`, with the context converted by `auto_any_cast`;
/// selected only when that call is well-formed for the (possibly non-const) `T`.
template <class T>
auto compile_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(x.compile(auto_any_cast(ctx), output_shape, input))
{
    return x.compile(auto_any_cast(ctx), output_shape, input);
}

/// Fallback for operators without `compile`: returns an empty object.
template <class T>
value compile_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{
    return value::object{};
}

/// Entry point used as the default for `operation::compile`.
///
/// `compile` is a non-const member in the interface, so `x` must keep its constness here:
/// taking it as `const T&` would make the rank<1> probe look for a *const* `compile` and
/// silently fall back to the empty object for ordinary (non-const) `compile` members. A
/// forwarding reference preserves constness and still accepts every argument the old
/// `const T&` signature accepted.
template <class T>
value compile_op(T&& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    return compile_op(rank<1>{}, x, ctx, output_shape, input);
}

420
421
422
423
424
425
/// Default `attributes`: an operator that does not report its own attributes has none, which
/// is represented as an empty object.
template <class T>
value attributes_op(const T&)
{
    using attribute_map = value::object;
    return attribute_map{};
}

426
427
428
429
430
431
432
433
434
template <class T>
value to_value_op(const T& x)
{
    return migraphx::to_value(x);
}

template <class T>
void from_value_op(T& x, const value& v)
{
435
436
    if(not(v.is_object() or (v.empty() and v.is_array())))
        MIGRAPHX_THROW("Value is not an object");
437
438
439
    return migraphx::from_value(v, x);
}

440
template <class T>
441
lifetime get_lifetime_op(const T&)
442
{
443
    return lifetime::local;
444
445
}

446
447
} // namespace detail

448
<%
Paul's avatar
Paul committed
449
450
451
 interface(
     'operation',
     virtual('name', returns = 'std::string', const = True),
Paul's avatar
Paul committed
452
453
     virtual(
         'is_context_free', returns = 'bool', const = True, default = 'detail::is_context_free_op'),
Shucai Xiao's avatar
Shucai Xiao committed
454
455
456
457
     virtual('need_normalization',
             returns = 'bool',
             const   = True,
             default = 'detail::need_normalization_op'),
458
     virtual('has_finalize', returns = 'bool', const = True, default = 'detail::has_finalize_op'),
459
460
     virtual(
         'get_lifetime', returns = 'lifetime', const = True, default = 'detail::get_lifetime_op'),
Paul's avatar
Paul committed
461
     virtual('output_alias',
Paul's avatar
Paul committed
462
             returns = 'std::ptrdiff_t',
Paul's avatar
Paul committed
463
464
             input   = 'const std::vector<shape>&',
             const   = True,
465
             default = 'detail::output_alias_op'),
466
467
468
469
470
471
     virtual('compile',
             returns = 'value',
             ctx     = 'context&',
             output  = 'const shape&',
             input   = 'const std::vector<shape>&',
             default = 'detail::compile_op'),
Paul's avatar
Paul committed
472
473
474
475
     virtual('finalize',
             ctx     = 'context&',
             output  = 'const shape&',
             input   = 'const std::vector<shape>&',
476
             default = 'detail::finalize_op'),
Shucai Xiao's avatar
Shucai Xiao committed
477
478
479
480
     virtual('compute_shape',
             returns = 'shape',
             input   = 'const std::vector<shape>&',
             const   = True,
481
             default = 'detail::compute_shape_op'),
Shucai Xiao's avatar
Shucai Xiao committed
482
483
484
485
486
     virtual('compute_shape',
             returns  = 'shape',
             inputs   = 'const std::vector<shape>&',
             mod_args = 'const std::vector<module_ref>&',
             const    = True,
487
             default  = 'detail::mod_compute_shape_op'),
Paul's avatar
Paul committed
488
489
490
491
492
493
     virtual('compute',
             returns = 'argument',
             ctx     = 'context&',
             output  = 'const shape&',
             input   = 'const std::vector<argument>&',
             const   = True,
494
             default = 'detail::compute_op'),
Paul's avatar
Paul committed
495
496
497
498
499
     virtual('compute',
             returns = 'argument',
             output  = 'const shape&',
             input   = 'const std::vector<argument>&',
             const   = True,
500
             default = 'detail::compute_op'),
Shucai Xiao's avatar
Shucai Xiao committed
501
502
503
     virtual(
         'compute',
         returns     = 'argument',
Shucai Xiao's avatar
Shucai Xiao committed
504
505
506
507
508
509
510
511
512
513
514
515
         output      = 'const shape&',
         input       = 'const std::vector<argument>&',
         module_args = 'const std::vector<module_ref>&',
         run =
             'std::function<std::vector<argument>(module_ref&, const std::unordered_map<std::string, argument>&)>',
         const   = True,
         default = 'detail::compute_op'),
     virtual(
         'compute',
         returns     = 'argument',
         ctx         = 'context&',
         output      = 'const shape&',
Shucai Xiao's avatar
Shucai Xiao committed
516
517
518
         input       = 'const std::vector<argument>&',
         module_args = 'const std::vector<module_ref>&',
         run =
Shucai Xiao's avatar
Shucai Xiao committed
519
             'std::function<std::vector<argument>(module_ref&, const std::unordered_map<std::string, argument>&)>',
Shucai Xiao's avatar
Shucai Xiao committed
520
521
         const   = True,
         default = 'detail::compute_op'),
522
523
     virtual('to_value', returns = 'value', const = True, default = 'detail::to_value_op'),
     virtual('from_value', v = 'const value&', default = 'detail::from_value_op'),
524
     virtual('attributes', returns = 'value', const = True, default = 'detail::attributes_op'),
Paul's avatar
Paul committed
525
526
527
528
     friend('operator<<',
            returns = 'std::ostream &',
            os      = 'std::ostream &',
            op      = 'const operation &',
529
            using   = 'migraphx::detail::operation_operators::operator<<'),
Paul's avatar
Paul committed
530
531
532
533
     friend('operator==',
            returns = 'bool',
            x       = 'const operation &',
            y       = 'const operation &',
534
            using   = 'migraphx::detail::operation_operators::operator==')) %>
Paul's avatar
Paul committed
535
536

    inline bool operator!=(const operation& x, const operation& y)
Paul's avatar
Paul committed
537
538
539
540
{
    return !(x == y);
}

541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
/// Compiles a type-erased operation against a type-erased context.
inline value
compile(operation& op, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    return op.compile(ctx, output_shape, input);
}

/// Compiles a type-erased operation against a concrete context: the context is wrapped via
/// `std::ref` (presumably so the type-erased `context` refers to `ctx` rather than copying it)
/// and forwarded to the overload above.
template <class Context>
inline value
compile(operation& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    dependent_type<context, Context> ctx2 = std::ref(ctx);
    return compile(op, ctx2, output_shape, input);
}

/// Calls `compile` directly on a concrete operator that accepts this context type. The
/// context is passed exactly once, matching the `compile(ctx, output, input)` member signature;
/// the previous `op.compile(ctx, ctx, ...)` form could never match that signature, so this
/// overload was never viable.
template <class T, class Context>
inline auto compile(T& op, Context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(op.compile(ctx, output_shape, input))
{
    return op.compile(ctx, output_shape, input);
}
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
inline shape compute_shape(const operation& op, const std::vector<shape>& inputs)
{
    return op.compute_shape(inputs);
}

template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
    -> decltype(op.compute_shape(inputs))
{
    return op.compute_shape(inputs);
}

template <class T>
inline auto compute_shape(const T& op, const std::vector<shape>& inputs)
    -> decltype(op.normalize_compute_shape(inputs))
{
575
    return detail::compute_shape_op(op, inputs);
576
577
}

Shucai Xiao's avatar
Shucai Xiao committed
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
inline shape compute_shape(const operation& op,
                           const std::vector<shape>& inputs,
                           const std::vector<module_ref>& mod_args)
{
    return op.compute_shape(inputs, mod_args);
}

template <class T>
inline auto compute_shape(const T& op,
                          const std::vector<shape>& inputs,
                          const std::vector<module_ref>& mod_args)
    -> decltype(op.compute_shape(inputs, mod_args))
{
    return op.compute_shape(inputs, mod_args);
}

template <class T>
inline auto compute_shape(const T& op,
                          const std::vector<shape>& inputs,
                          const std::vector<module_ref>& mod_args)
    -> decltype(op.normalize_compute_shape(inputs, mod_args))
{
600
    return detail::compute_shape_op(op, inputs, mod_args);
Shucai Xiao's avatar
Shucai Xiao committed
601
602
}

Paul's avatar
Paul committed
603
inline bool is_context_free(const operation& op) { return op.is_context_free(); }
Paul's avatar
Paul committed
604

Paul's avatar
Paul committed
605
template <class T>
Paul's avatar
Paul committed
606
607
bool is_context_free(const T& x)
{
608
    return detail::is_context_free_op(x);
Paul's avatar
Paul committed
609
610
}

Shucai Xiao's avatar
Shucai Xiao committed
611
612
613
614
615
616
617
618
inline bool need_normalization(const operation& op) { return op.need_normalization(); }

template <class T>
bool need_normalization(const T& x)
{
    return detail::need_normalization_op(x);
}

Paul's avatar
Paul committed
619
620
621
622
623
inline bool has_finalize(const operation& op) { return op.has_finalize(); }

template <class T>
bool has_finalize(const T& x)
{
624
    return detail::has_finalize_op(x);
Paul's avatar
Paul committed
625
626
}

627
628
629
/// Serializes `op` into `v`. Defined elsewhere; presumably the hook picked up by the generic
/// `value` serialization in serialize.hpp -- TODO confirm.
void migraphx_to_value(value& v, const operation& op);
/// Deserializes `v` into `op`. Defined elsewhere; counterpart of `migraphx_to_value`.
void migraphx_from_value(const value& v, operation& op);

Paul's avatar
Paul committed
630
631
#endif

Paul's avatar
Paul committed
632
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
633
} // namespace migraphx
Paul's avatar
Paul committed
634
635

#endif