operation.hpp 9.26 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <cassert>
Paul's avatar
Paul committed
5
6
7
8
9
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
Paul's avatar
Paul committed
10
11
12
13
14
15
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/auto_any_cast.hpp>
Paul's avatar
Paul committed
16
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
17

Paul's avatar
Paul committed
18
namespace migraphx {
Paul's avatar
Paul committed
19
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
20

Paul's avatar
Paul committed
21
22
23
24
25
26
27
28
#ifdef DOXYGEN

/// The operation interface represents an action an instruction will perform. All
/// operation classes must be CopyConstructible.
struct operation
{
    /// A unique name identifying the operation
    std::string name() const;
Paul's avatar
Paul committed
29
30
    /// An optional method that can be used to finalize the operator before running
    void finalize(context& ctx);
Paul's avatar
Paul committed
31
32
33
    /// This is used to compute the resulting shape from an operation. If an
    /// operation cannot be run with input shapes, then it should throw an
    /// exception.
Paul's avatar
Paul committed
34
    shape compute_shape(const std::vector<shape>& input) const;
Paul's avatar
Paul committed
35
    /**
Paul's avatar
Paul committed
36
37
     * @brief This performs the operation's computation.
     *
Paul's avatar
Paul committed
38
39
     * This method can be optional when the operation is only used as a placeholder to be lowered
     * later on.
Paul's avatar
Paul committed
40
41
42
43
44
     *
     * @param ctx This is the context created by the `target` during compilation. Implementations
     * can use the target's `context` class rather than the `context` interface class.
     * @param output This is the output shape. It is equivalent to running `compute_shape` with each
     * `shape` of the `argument`.
Paul's avatar
Paul committed
45
     * @param input This is the `argument` result from the previous instruction's computation.
Paul's avatar
Paul committed
46
47
     * @return Return an `argument` of the result computation. The `shape` of `argument` should be
     * the same the `output` shape.
Paul's avatar
Paul committed
48
     */
Paul's avatar
Paul committed
49
    argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
Paul's avatar
Paul committed
50
51
52
    /// An optional method to return which argument the output will alias. If
    /// there is no aliased output then -1 can be returned.
    int output_alias(const std::vector<shape>& input) const;
Paul's avatar
Paul committed
53
54
    /// An optional stream operator to print the operation. When this is not
    /// implemented, it will just print the operation's name.
Paul's avatar
Paul committed
55
    friend std::ostream& operator<<(std::ostream& os, const operation& op);
Paul's avatar
Paul committed
56
57
};

Paul's avatar
Paul committed
58
59
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);
Paul's avatar
Paul committed
60
61
/// Returns true if the operation has a finalize method
bool has_finalize(const operation& x);
Paul's avatar
Paul committed
62

Paul's avatar
Paul committed
63
64
#else

Paul's avatar
Paul committed
65
66
namespace operation_stream {

/// Generic stream operator for any type exposing `name()`.
///
/// Writes the operation's name. If the type reflects any fields, it then writes
/// a bracketed, comma-separated list of `field=value` entries. Each value is
/// written by `stream_write_value`. Types with no reflected fields print just
/// their name, with no empty brackets.
template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{
    os << x.name();
    bool first = true;
    reflect_each(x, [&](auto& field, auto field_name) {
        // The opening bracket is emitted lazily so field-less ops print no "[]".
        if(first)
            os << '[';
        else
            os << ',';
        first = false;
        os << field_name << "=";
        stream_write_value(os, field);
    });
    // Close the bracket only if at least one field was written.
    if(!first)
        os << "]";
    return os;
}

} // namespace operation_stream
Paul's avatar
Paul committed
84

Paul's avatar
Paul committed
85
86
87
88
89
90
91
92
93
94
95
96
97
namespace operation_equal {

/// Generic equality for operation types.
///
/// Two operations are equal when their names match and, after viewing `y` as a
/// `T`, every reflected field of `x` equals the corresponding field of `y`.
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
    // Differently named operations never compare equal. Checking this first also
    // guards the any_cast below, which is only meaningful once the names match.
    const bool same_name = (x.name() == y.name());
    if(!same_name)
        return false;
    const auto& other = any_cast<T>(y);
    return reflect_tie(x) == reflect_tie(other);
}

} // namespace operation_equal

Paul's avatar
Paul committed
98
template <class T>
Paul's avatar
Paul committed
99
auto compute_op(rank<2>,
Paul's avatar
Paul committed
100
101
102
103
104
                const T& x,
                context& ctx,
                const shape& output_shape,
                const std::vector<argument>& input)
    -> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
Paul's avatar
Paul committed
105
106
107
108
{
    return x.compute(auto_any_cast(ctx), output_shape, input);
}

Paul's avatar
Paul committed
109
template <class T>
Paul's avatar
Paul committed
110
111
auto compute_op(
    rank<1>, const T& x, context&, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
112
113
114
115
116
    -> decltype(x.compute(output_shape, input))
{
    return x.compute(output_shape, input);
}

Paul's avatar
Paul committed
117
template <class T>
Paul's avatar
Paul committed
118
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
Paul's avatar
Paul committed
119
{
Paul's avatar
Paul committed
120
    std::string name = x.name();
Paul's avatar
Paul committed
121
    MIGRAPHX_THROW("Not computable: " + name);
Paul's avatar
Paul committed
122
123
}

Paul's avatar
Paul committed
124
template <class T>
Paul's avatar
Paul committed
125
126
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
127
{
Paul's avatar
Paul committed
128
129
130
131
    return compute_op(rank<2>{}, x, ctx, output_shape, input);
}

template <class T>
Paul's avatar
Paul committed
132
auto compute_op(rank<2>, const T& x, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
133
134
135
136
137
138
    -> decltype(x.compute(output_shape, input))
{
    return x.compute(output_shape, input);
}

template <class T>
Paul's avatar
Paul committed
139
auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
140
141
142
143
144
145
146
147
148
149
150
151
152
153
    -> decltype(x.compute(auto_any_cast(std::declval<context&>()), output_shape, input))
{
    std::string name = x.name();
    MIGRAPHX_THROW("Not computable without a context: " + name);
}

template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
    std::string name = x.name();
    MIGRAPHX_THROW("Not computable: " + name);
}

template <class T>
Paul's avatar
Paul committed
154
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
155
156
157
158
159
{
    return compute_op(rank<2>{}, x, output_shape, input);
}

template <class T>
Paul's avatar
Paul committed
160
161
162
163
164
auto is_context_free_op(rank<1>,
                        const T& x,
                        const shape& output_shape,
                        const std::vector<argument>& input)
    -> decltype(x.compute(output_shape, input), std::true_type{});
Paul's avatar
Paul committed
165
166

template <class T>
Paul's avatar
Paul committed
167
168
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
    -> std::false_type;
Paul's avatar
Paul committed
169
170

template <class T>
Paul's avatar
Paul committed
171
172
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(
    rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
Paul's avatar
Paul committed
173
174
{
    return {};
Paul's avatar
Paul committed
175
176
}

Paul's avatar
Paul committed
177
template <class T>
Paul's avatar
Paul committed
178
179
180
181
182
int output_alias_op(rank<0>, const T&, const std::vector<shape>&)
{
    return -1;
}

Paul's avatar
Paul committed
183
184
185
template <class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
    -> decltype(x.output_alias(shapes))
Paul's avatar
Paul committed
186
187
188
189
{
    return x.output_alias(shapes);
}

Paul's avatar
Paul committed
190
template <class T>
Paul's avatar
Paul committed
191
192
193
194
195
int output_alias_op(const T& x, const std::vector<shape>& shapes)
{
    return output_alias_op(rank<1>{}, x, shapes);
}

Paul's avatar
Paul committed
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
template <class T>
auto finalize_op(rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), void())
{
    x.finalize(auto_any_cast(ctx), output_shape, input);
}

template <class T>
void finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
{}

template <class T>
void finalize_op(T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
{
    finalize_op(rank<1>{}, x, ctx, output_shape, input);
}

/// Selected when `x.finalize(ctx, output, inputs)` is well-formed. Declared but
/// never defined, so it may only appear in unevaluated operands such as `decltype`.
template <class T>
auto has_finalize_op(
    rank<1>, T& x, context& ctx, const shape& output_shape, const std::vector<shape>& input)
    -> decltype(x.finalize(auto_any_cast(ctx), output_shape, input), std::true_type{});

/// Fallback: no suitable `finalize` method (declaration only).
template <class T>
auto has_finalize_op(rank<0>, T&, context&, const shape&, const std::vector<shape>&)
    -> std::false_type;

/// Yields `std::true_type` if `T` has a `finalize` method callable the way
/// `finalize_op` calls it, otherwise `std::false_type`.
///
/// `finalize` is non-const, so the probe uses a mutable `T&` even though the
/// argument itself is taken by const reference.
template <class T>
auto has_finalize_op(const T&) -> decltype(has_finalize_op(rank<1>{},
                                                           std::declval<T&>(),
                                                           std::declval<context&>(),
                                                           std::declval<const shape&>(),
                                                           std::declval<std::vector<shape>>()))
{
    // The answer is carried entirely by the return type.
    return {};
}

232
// NOTE(review): The `<% ... %>` block below is a code-generator directive, not C++.
// Going by the names it references, `interface('operation', ...)` describes the
// type-erased `operation` class. Each `virtual(...)` entry presumably becomes a
// member function. Entries with `default = '..._op'` name one of the `*_op` helpers
// defined above. Each helper calls the wrapped type's own method when it has one and
// otherwise supplies a fallback (e.g. `output_alias_op` returns -1, `finalize_op`
// does nothing, and `compute_op` throws). `name` and `compute_shape` have no default.
// The two `friend(...)` entries point `operator<<` and `operator==` at the generic
// implementations in `operation_stream` and `operation_equal`.
// `finalize` is the only entry without `const = True`, matching the mutable `T&`
// taken by `finalize_op`. It declares no `returns`, so it is presumably a `void`
// member; confirm against the generator.
<%
Paul's avatar
Paul committed
233
234
235
 interface(
     'operation',
     virtual('name', returns = 'std::string', const = True),
Paul's avatar
Paul committed
236
     virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
Paul's avatar
Paul committed
237
     virtual('has_finalize', returns = 'bool', const = True, default = 'has_finalize_op'),
Paul's avatar
Paul committed
238
239
240
241
242
     virtual('output_alias',
             returns = 'int',
             input   = 'const std::vector<shape>&',
             const   = True,
             default = 'output_alias_op'),
Paul's avatar
Paul committed
243
     virtual('finalize', ctx = 'context&', output  = 'const shape&', input = 'const std::vector<shape>&', default = 'finalize_op'),
Paul's avatar
Paul committed
244
245
246
247
248
249
250
251
     virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
     virtual('compute',
             returns = 'argument',
             ctx     = 'context&',
             output  = 'const shape&',
             input   = 'const std::vector<argument>&',
             const   = True,
             default = 'compute_op'),
Paul's avatar
Paul committed
252
253
254
255
256
257
     virtual('compute',
             returns = 'argument',
             output  = 'const shape&',
             input   = 'const std::vector<argument>&',
             const   = True,
             default = 'compute_op'),
Paul's avatar
Paul committed
258
259
260
261
     friend('operator<<',
            returns = 'std::ostream &',
            os      = 'std::ostream &',
            op      = 'const operation &',
Paul's avatar
Paul committed
262
            using   = 'migraphx::operation_stream::operator<<'),
Paul's avatar
Paul committed
263
264
265
266
     friend('operator==',
            returns = 'bool',
            x       = 'const operation &',
            y       = 'const operation &',
Paul's avatar
Paul committed
267
            using   = 'migraphx::operation_equal::operator==')) %>
Paul's avatar
Paul committed
268
269

    inline bool operator!=(const operation& x, const operation& y)
Paul's avatar
Paul committed
270
271
272
273
{
    return !(x == y);
}

Paul's avatar
Paul committed
274
/// Type-erased query: returns the answer of the wrapped operation's
/// `is_context_free` member.
inline bool is_context_free(const operation& op)
{
    return op.is_context_free();
}

/// Returns true if `x` provides a `compute` overload that does not need a
/// context. The answer is determined at compile time from `T`'s interface.
template <class T>
bool is_context_free(const T& x)
{
    return decltype(is_context_free_op(x))::value;
}

Paul's avatar
Paul committed
282
283
284
285
286
287
288
289
/// Type-erased query: returns the answer of the wrapped operation's
/// `has_finalize` member.
inline bool has_finalize(const operation& op)
{
    return op.has_finalize();
}

/// Returns true if `x` provides a `finalize` method. The answer is determined
/// at compile time from `T`'s interface.
template <class T>
bool has_finalize(const T& x)
{
    return decltype(has_finalize_op(x))::value;
}

Paul's avatar
Paul committed
290
291
#endif

Paul's avatar
Paul committed
292
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
293
} // namespace migraphx
Paul's avatar
Paul committed
294
295

#endif