// operation.hpp
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_OPERAND_HPP

#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/auto_any_cast.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

#ifdef DOXYGEN

/// The operation interface represents an action an instruction will perform. All
/// operation classes must be CopyConstructible.
struct operation
{
    /// A unique name identifying the operation
    std::string name() const;
    /// This is used to compute the resulting shape from an operation. If an
    /// operation cannot be run with input shapes, then it should throw an
    /// exception.
Paul's avatar
Paul committed
32
    shape compute_shape(const std::vector<shape>& input) const;
Paul's avatar
Paul committed
33
    /**
Paul's avatar
Paul committed
34
35
     * @brief This performs the operation's computation.
     *
Paul's avatar
Paul committed
36
37
     * This method can be optional when the operation is only used as a placeholder to be lowered
     * later on.
Paul's avatar
Paul committed
38
39
40
41
42
     *
     * @param ctx This is the context created by the `target` during compilation. Implementations
     * can use the target's `context` class rather than the `context` interface class.
     * @param output This is the output shape. It is equivalent to running `compute_shape` with each
     * `shape` of the `argument`.
Paul's avatar
Paul committed
43
     * @param input This is the `argument` result from the previous instruction's computation.
Paul's avatar
Paul committed
44
45
     * @return Return an `argument` of the result computation. The `shape` of `argument` should be
     * the same the `output` shape.
Paul's avatar
Paul committed
46
     */
Paul's avatar
Paul committed
47
    argument compute(context& ctx, const shape& output, const std::vector<argument>& input) const;
Paul's avatar
Paul committed
48
49
50
    /// An optional method to return which argument the output will alias. If
    /// there is no aliased output then -1 can be returned.
    int output_alias(const std::vector<shape>& input) const;
Paul's avatar
Paul committed
51
52
    /// An optional stream operator to print the operation. When this is not
    /// implemented, it will just print the operation's name.
Paul's avatar
Paul committed
53
    friend std::ostream& operator<<(std::ostream& os, const operation& op);
Paul's avatar
Paul committed
54
55
};

Paul's avatar
Paul committed
56
57
58
/// Returns true if operation does not require a context to run compute
bool is_context_free(const operation& x);

Paul's avatar
Paul committed
59
60
#else

namespace operation_stream {

/// Default stream operator for operations.
///
/// Prints the op's name, then its reflected fields as `name[field1=value1,field2=value2]`.
/// If no field is visited, only the name is printed (no brackets). `reflect_each` is assumed
/// to invoke the callback once per reflected field.
///
/// Participates in overload resolution only when `os << x.name()` is well-formed.
template <class T>
auto operator<<(std::ostream& os, const T& x) -> decltype(os << x.name())
{
    os << x.name();
    // '[' opens the field list before the first field; every later field is preceded by ','.
    char delim = '[';
    reflect_each(x, [&](auto& y, auto name) {
        os << delim;
        os << name << "=";
        stream_write_value(os, y);
        delim = ',';
    });
    // delim only becomes ',' once at least one field was written, so close the list only then.
    if(delim == ',')
        os << "]";
    return os;
}

} // namespace operation_stream

namespace operation_equal {

/// Generic equality used as the default `operator==` for operations.
///
/// Two ops are equal when their names match and their reflected fields compare equal.
/// Once the names agree, `y` is converted to `T` with `any_cast`. This assumes ops sharing a
/// name share a concrete type (NOTE(review): confirm how `any_cast` reacts to a mismatch).
template <class T, class U>
auto operator==(const T& x, const U& y) -> decltype(x.name() == y.name())
{
    if(x.name() == y.name())
    {
        const auto& rhs = any_cast<T>(y);
        return reflect_tie(x) == reflect_tie(rhs);
    }
    return false;
}

} // namespace operation_equal

template <class T>
Paul's avatar
Paul committed
95
auto compute_op(rank<2>,
Paul's avatar
Paul committed
96
97
98
99
100
                const T& x,
                context& ctx,
                const shape& output_shape,
                const std::vector<argument>& input)
    -> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
Paul's avatar
Paul committed
101
102
103
104
{
    return x.compute(auto_any_cast(ctx), output_shape, input);
}

Paul's avatar
Paul committed
105
106
107
108
109
110
111
112
113
114
115
/// Second-choice overload of the context-taking `compute_op`.
///
/// Used when `T` only offers the context-free `compute(output_shape, input)`; the caller's
/// context is simply not used.
template <class T>
auto compute_op(rank<1>,
                const T& x,
                context& /*unused_ctx*/,
                const shape& out_shape,
                const std::vector<argument>& args) -> decltype(x.compute(out_shape, args))
{
    return x.compute(out_shape, args);
}

template <class T>
Paul's avatar
Paul committed
117
argument compute_op(rank<0>, const T& x, context&, const shape&, const std::vector<argument>&)
Paul's avatar
Paul committed
118
{
Paul's avatar
Paul committed
119
    std::string name = x.name();
Paul's avatar
Paul committed
120
    MIGRAPHX_THROW("Not computable: " + name);
Paul's avatar
Paul committed
121
122
}

Paul's avatar
Paul committed
123
template <class T>
Paul's avatar
Paul committed
124
125
argument
compute_op(const T& x, context& ctx, const shape& output_shape, const std::vector<argument>& input)
Paul's avatar
Paul committed
126
{
Paul's avatar
Paul committed
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
    return compute_op(rank<2>{}, x, ctx, output_shape, input);
}

/// Preferred overload of the context-free `compute_op`.
///
/// Forwards to `T`'s own `compute(output_shape, input)` when that member exists.
template <class T>
auto compute_op(rank<2>, const T& x, const shape& out_shape, const std::vector<argument>& args)
    -> decltype(x.compute(out_shape, args))
{
    return x.compute(out_shape, args);
}

/// Chosen when `T` has only a context-taking `compute`.
///
/// No context is available on this path, so the op is reported as not computable.
template <class T>
auto compute_op(rank<1>, const T& x, const shape& out_shape, const std::vector<argument>& args)
    -> decltype(x.compute(auto_any_cast(std::declval<context&>()), out_shape, args))
{
    const std::string op_name = x.name();
    MIGRAPHX_THROW("Not computable without a context: " + op_name);
}

/// Last-resort overload of the context-free `compute_op`, used when `T` has no usable
/// `compute` member at all.
template <class T>
argument compute_op(rank<0>, const T& x, const shape&, const std::vector<argument>&)
{
    const std::string op_name = x.name();
    MIGRAPHX_THROW("Not computable: " + op_name);
}

/// Runs `x` without a context.
///
/// Dispatch starts at `rank<2>` (the op's context-free `compute`). Ops that need a context, or
/// have no `compute`, end up in an overload that reports an error.
template <class T>
argument compute_op(const T& x, const shape& output_shape, const std::vector<argument>& input)
{
    return compute_op(rank<2>{}, x, output_shape, input);
}

template <class T>
auto is_context_free_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input) -> decltype(x.compute(output_shape, input), std::true_type{});

template <class T>
auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&) -> std::false_type;

template <class T>
auto is_context_free_op(const T& x) -> decltype(is_context_free_op(rank<1>{}, x, std::declval<const shape&>(), std::declval<std::vector<argument>>()))
{
    return {};
Paul's avatar
Paul committed
175
176
}

Paul's avatar
Paul committed
177
template <class T>
Paul's avatar
Paul committed
178
179
180
181
182
int output_alias_op(rank<0>, const T&, const std::vector<shape>&)
{
    return -1;
}

Paul's avatar
Paul committed
183
184
185
/// Preferred overload, viable only when `T` declares its own `output_alias(shapes)` member.
///
/// Forwards to that member.
template <class T>
auto output_alias_op(rank<1>, const T& x, const std::vector<shape>& shapes)
    -> decltype(x.output_alias(shapes))
{
    return x.output_alias(shapes);
}

template <class T>
Paul's avatar
Paul committed
191
192
193
194
195
int output_alias_op(const T& x, const std::vector<shape>& shapes)
{
    return output_alias_op(rank<1>{}, x, shapes);
}

196
// Type-erasure interface specification for `operation`. The directive below is not plain C++:
// the project's code generator (not shown here) expands it into the `operation` class.
// Each `default = '...'` names one of the helper functions defined earlier in this file, e.g.
// `compute_op`, `output_alias_op` and `is_context_free_op`. Presumably each helper is used when
// the wrapped type does not supply that member (NOTE(review): confirm against the generator).
<%
 interface(
     'operation',
     virtual('name', returns = 'std::string', const = True),
     virtual('is_context_free', returns = 'bool', const = True, default = 'is_context_free_op'),
     virtual('output_alias',
             returns = 'int',
             input   = 'const std::vector<shape>&',
             const   = True,
             default = 'output_alias_op'),
     virtual('compute_shape', returns = 'shape', input = 'const std::vector<shape>&', const = True),
     virtual('compute',
             returns = 'argument',
             ctx     = 'context&',
             output  = 'const shape&',
             input   = 'const std::vector<argument>&',
             const   = True,
             default = 'compute_op'),
     virtual('compute',
             returns = 'argument',
             output  = 'const shape&',
             input   = 'const std::vector<argument>&',
             const   = True,
             default = 'compute_op'),
     friend('operator<<',
            returns = 'std::ostream &',
            os      = 'std::ostream &',
            op      = 'const operation &',
            using   = 'migraphx::operation_stream::operator<<'),
     friend('operator==',
            returns = 'bool',
            x       = 'const operation &',
            y       = 'const operation &',
            using   = 'migraphx::operation_equal::operator==')) %>

    inline bool operator!=(const operation& x, const operation& y)
Paul's avatar
Paul committed
232
233
234
235
{
    return !(x == y);
}

Paul's avatar
Paul committed
236
237
238
239
240
241
242
243
244
245
246
inline bool is_context_free(const operation& op)
{
    return op.is_context_free();
}

/// Reports whether a concrete op type `T` can compute without a `context`.
///
/// Returns true when `x.compute(shape, std::vector<argument>)` is a valid call; decided at
/// compile time via `is_context_free_op`.
template <class T>
bool is_context_free(const T& x)
{
    // is_context_free_op yields std::true_type / std::false_type; convert that to a plain bool.
    return static_cast<bool>(is_context_free_op(x));
}

#endif

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif