context.hpp 9.73 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_CONTEXT_HPP
#define MIGRAPHX_GUARD_CONTEXT_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <cassert>
Paul's avatar
Paul committed
5
6
7
8
9
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>
Paul's avatar
Paul committed
10
#include <migraphx/config.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
11
#include <migraphx/value.hpp>
12
#include <migraphx/any_ptr.hpp>
Paul's avatar
Paul committed
13

Paul's avatar
Paul committed
14
namespace migraphx {
Paul's avatar
Paul committed
15
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
16

Paul's avatar
Paul committed
17
18
19
20
21
22
23
#ifdef DOXYGEN

/// A context is used to store internal data for a `target`. A context is
/// constructed by a target during compilation and passed to the operations
/// during `eval`.
///
/// Only `finish` is required. A concrete context may additionally provide
/// `to_value() const`, `from_value(const value&)` and `get_queue()`. When one
/// of these is missing, the type-erased wrapper falls back to the free
/// functions `to_value_context` (returns a default-constructed `value`),
/// `from_value_context` (does nothing) and `get_queue_context` (returns a
/// default-constructed `any_ptr`).
struct context
{
    /// Wait for any tasks in the context to complete
    void finish() const;
};

#else

Shucai Xiao's avatar
Shucai Xiao committed
30
31
32
33
34
35
36
37
38
39
40
/// Fallback serializer, selected by the type-erased `context` when the
/// concrete context type has no `to_value` member: every such context
/// serializes to a default-constructed `value`.
template <class T>
auto to_value_context(const T& /* ctx */) -> value
{
    return value{};
}

/// Fallback deserializer, selected by the type-erased `context` when the
/// concrete context type has no `from_value` member.
template <class T>
void from_value_context(T& /* ctx */, const value& /* v */)
{
    // Intentionally a no-op.
}

41
42
43
44
45
46
/// Fallback queue accessor, selected by the type-erased `context` when the
/// concrete context type has no `get_queue` member: it hands back a
/// default-constructed `any_ptr`.
template <class T>
auto get_queue_context(T& /* ctx */) -> any_ptr
{
    return any_ptr{};
}

47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#ifdef TYPE_ERASED_DECLARATION

// Declaration-only summary of the interface wrapped by the type-erased context:
struct context
{
    // Required: wait for any tasks in the context to complete.
    void finish() const;

    // Optional: serialize the context (fallback returns a default `value`).
    value to_value() const;
    // Optional: restore the context from `v` (fallback does nothing).
    void from_value(const value& v);
    // Optional: the context's queue, type-erased (fallback returns a default `any_ptr`).
    any_ptr get_queue();
};

#else
Paul's avatar
Paul committed
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

struct context
{
    // Constructors

    /// Construct an empty context. `type_id()` then reports `std::nullptr_t`,
    /// and the forwarding members (`to_value`, `finish`, ...) must not be
    /// called because they assert that a handle is present.
    context() = default;

    /// Wrap a concrete context, stored by value. Wrapping a
    /// `std::reference_wrapper<T>` instead stores a reference to the `T`.
    /// Deliberately implicit so any concrete context converts to `context`.
    template <typename PrivateDetailTypeErasedT>
    context(PrivateDetailTypeErasedT value)
        : private_detail_te_handle_mem_var(
              std::make_shared<private_detail_te_handle_type<
                  typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
                  std::forward<PrivateDetailTypeErasedT>(value)))
    {
    }

    // Assignment

    /// Replace the held object with `value`.
    /// If the held object already has type PrivateDetailTypeErasedT, it is
    /// assigned in place (after `any_cast` has un-shared the handle).
    /// Otherwise a fresh handle is built and swapped in.
    template <typename PrivateDetailTypeErasedT>
    context& operator=(PrivateDetailTypeErasedT value)
    {
        using std::swap;
        auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
        // shared_ptr::unique() is deprecated in C++17 and removed in C++20;
        // use_count() == 1 is its documented equivalent.
        if(derived and private_detail_te_handle_mem_var.use_count() == 1)
        {
            *derived = std::forward<PrivateDetailTypeErasedT>(value);
        }
        else
        {
            // `value` is a by-value sink that is not used again, so move it
            // instead of copying it into the new handle.
            context rhs(std::move(value));
            swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
        }
        return *this;
    }

    // Cast

    /// Pointer to the held object if its type is PrivateDetailTypeErasedT,
    /// otherwise nullptr. On a match this goes through the non-const handle
    /// accessor, so a shared handle is cloned first (copy-on-write).
    template <typename PrivateDetailTypeErasedT>
    PrivateDetailTypeErasedT* any_cast()
    {
        // NOTE(review): for a context built from std::reference_wrapper<T>,
        // type() reports typeid(T), but the handle is a
        // private_detail_te_handle_type<std::reference_wrapper<T>> rather than a
        // private_detail_te_handle_type<T>. The static_cast below therefore looks
        // invalid in that case; confirm before relying on any_cast for such contexts.
        return this->type_id() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    /// Const overload of any_cast: never clones the handle.
    template <typename PrivateDetailTypeErasedT>
    const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
    {
        return this->type_id() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<const private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    /// Dynamic type of the held object, or typeid(std::nullptr_t) when empty.
    const std::type_info& type_id() const
    {
        if(private_detail_te_handle_empty())
            return typeid(std::nullptr_t);
        else
            return private_detail_te_get_handle().type();
    }

    /// Serialize the held context. Uses its `to_value()` member if it has one,
    /// otherwise `to_value_context`.
    value to_value() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().to_value();
    }

    /// Restore the held context from `v`. Uses its `from_value` member if it
    /// has one, otherwise `from_value_context`. A shared handle is cloned first.
    void from_value(const value& v)
    {
        assert((*this).private_detail_te_handle_mem_var);
        (*this).private_detail_te_get_handle().from_value(v);
    }

    /// Queue of the held context. Uses its `get_queue()` member if it has one,
    /// otherwise `get_queue_context`. A shared handle is cloned first.
    any_ptr get_queue()
    {
        assert((*this).private_detail_te_handle_mem_var);
        return (*this).private_detail_te_get_handle().get_queue();
    }

    /// Forward to the held context's required `finish()` member.
    void finish() const
    {
        assert((*this).private_detail_te_handle_mem_var);
        (*this).private_detail_te_get_handle().finish();
    }

    /// True when both contexts refer to the same handle. This includes the
    /// case where both are empty.
    friend bool is_shared(const context& private_detail_x, const context& private_detail_y)
    {
        return private_detail_x.private_detail_te_handle_mem_var ==
               private_detail_y.private_detail_te_handle_mem_var;
    }

    private:
    // Abstract handle: one virtual per erased operation, plus clone/type.
    struct private_detail_te_handle_base_type
    {
        virtual ~private_detail_te_handle_base_type() = default;
        virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
        virtual const std::type_info& type() const                                = 0;

        virtual value to_value() const          = 0;
        virtual void from_value(const value& v) = 0;
        virtual any_ptr get_queue()             = 0;
        virtual void finish() const             = 0;
    };

    // Member-detection dispatch. Callers pass char(0). The `char` overload is
    // an exact match and wins whenever its trailing decltype is well-formed,
    // i.e. whenever the concrete type has the member. Otherwise only the
    // `float` overload (reached via conversion) remains, and it forwards to the
    // free fallback.
    template <class T>
    static auto private_detail_te_default_to_value(char, T&& private_detail_te_self)
        -> decltype(private_detail_te_self.to_value())
    {
        return private_detail_te_self.to_value();
    }

    template <class T>
    static value private_detail_te_default_to_value(float, T&& private_detail_te_self)
    {
        return to_value_context(private_detail_te_self);
    }

    // The return type is forced to void: the body has no return statement, so
    // a concrete from_value returning non-void must not make this a non-void
    // function that flows off its end.
    template <class T>
    static auto
    private_detail_te_default_from_value(char, T&& private_detail_te_self, const value& v)
        -> decltype(void(private_detail_te_self.from_value(v)))
    {
        private_detail_te_self.from_value(v);
    }

    template <class T>
    static void
    private_detail_te_default_from_value(float, T&& private_detail_te_self, const value& v)
    {
        from_value_context(private_detail_te_self, v);
    }

    template <class T>
    static auto private_detail_te_default_get_queue(char, T&& private_detail_te_self)
        -> decltype(private_detail_te_self.get_queue())
    {
        return private_detail_te_self.get_queue();
    }

    template <class T>
    static any_ptr private_detail_te_default_get_queue(float, T&& private_detail_te_self)
    {
        return get_queue_context(private_detail_te_self);
    }

    // Concrete handle holding the erased object. The object is held by value,
    // or by reference when PrivateDetailTypeErasedT is a reference type.
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type : private_detail_te_handle_base_type
    {
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
                nullptr)
            : private_detail_te_value(value)
        {
        }

        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
                                    int>::type* = nullptr) noexcept
            : private_detail_te_value(std::move(value))
        {
        }

        std::shared_ptr<private_detail_te_handle_base_type> clone() const override
        {
            return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
        }

        const std::type_info& type() const override { return typeid(private_detail_te_value); }

        value to_value() const override
        {
            return private_detail_te_default_to_value(char(0), private_detail_te_value);
        }

        void from_value(const value& v) override
        {
            private_detail_te_default_from_value(char(0), private_detail_te_value, v);
        }

        any_ptr get_queue() override
        {
            return private_detail_te_default_get_queue(char(0), private_detail_te_value);
        }

        void finish() const override { private_detail_te_value.finish(); }

        PrivateDetailTypeErasedT private_detail_te_value;
    };

    // std::reference_wrapper<T> is unwrapped into a reference-holding handle.
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
        : private_detail_te_handle_type<PrivateDetailTypeErasedT&>
    {
        private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
            : private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
        {
        }
    };

    bool private_detail_te_handle_empty() const
    {
        return private_detail_te_handle_mem_var == nullptr;
    }

    const private_detail_te_handle_base_type& private_detail_te_get_handle() const
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        return *private_detail_te_handle_mem_var;
    }

    // Mutable access: clones the handle first if other contexts share it, so
    // mutations never leak into those contexts.
    private_detail_te_handle_base_type& private_detail_te_get_handle()
    {
        assert(private_detail_te_handle_mem_var != nullptr);
        // Equivalent to the former !unique(), which was removed in C++20.
        if(private_detail_te_handle_mem_var.use_count() != 1)
            private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
        return *private_detail_te_handle_mem_var;
    }

    std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};

/// Pointer form of any_cast. Returns a pointer to the object held by `*x`
/// when it has type ValueType, and nullptr otherwise. Like std::any_cast, a
/// null `x` also yields nullptr.
template <typename ValueType>
inline const ValueType* any_cast(const context* x)
{
    // Do not dereference a null context pointer.
    if(x == nullptr)
        return nullptr;
    return x->any_cast<ValueType>();
}

/// Mutable pointer form of any_cast; a null `x` yields nullptr.
template <typename ValueType>
inline ValueType* any_cast(context* x)
{
    // Do not dereference a null context pointer.
    if(x == nullptr)
        return nullptr;
    return x->any_cast<ValueType>();
}

/// Reference form of any_cast.
/// @throws std::bad_cast if `x` does not hold a ValueType.
template <typename ValueType>
inline ValueType& any_cast(context& x)
{
    auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
    if(y == nullptr)
        throw std::bad_cast();
    return *y;
}

/// Const reference form of any_cast.
/// @throws std::bad_cast if `x` does not hold a ValueType.
template <typename ValueType>
inline const ValueType& any_cast(const context& x)
{
    const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
    if(y == nullptr)
        throw std::bad_cast();
    return *y;
}
323
#endif
Paul's avatar
Paul committed
324

Shucai Xiao's avatar
Shucai Xiao committed
325
326
327
/// Store the serialized form of `ctx` (via context::to_value) in `v`.
inline void migraphx_to_value(value& v, const context& ctx)
{
    value serialized = ctx.to_value();
    v                = std::move(serialized);
}

/// Restore `ctx` from `v` (via context::from_value).
inline void migraphx_from_value(const value& v, context& ctx)
{
    ctx.from_value(v);
}

Paul's avatar
Paul committed
328
329
#endif

Paul's avatar
Paul committed
330
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
331
} // namespace migraphx
Paul's avatar
Paul committed
332
333

#endif