operation.hpp 6.82 KB
Newer Older
Paul's avatar
Paul committed
1
#ifndef RTG_GUARD_RTGLIB_OPERAND_HPP
Paul's avatar
Paul committed
2
#define RTG_GUARD_RTGLIB_OPERAND_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>
#include <rtg/shape.hpp>
#include <rtg/argument.hpp>
Paul's avatar
Paul committed
11
12
13

namespace rtg {

Paul's avatar
Paul committed
14
/*
 * Type-erased interface for:
 *
 * struct operation
 * {
 *     std::string name() const;
 *     shape compute_shape(std::vector<shape> input) const;
 *     argument compute(std::vector<argument> input) const;
 * };
 *
 */
Paul's avatar
Paul committed
25

Paul's avatar
Paul committed
26
struct operation
Paul's avatar
Paul committed
27
{
Paul's avatar
Paul committed
28
    // Constructors
Paul's avatar
Paul committed
29
    operation() = default;
Paul's avatar
Paul committed
30

Paul's avatar
Paul committed
31
    template <typename PrivateDetailTypeErasedT>
Paul's avatar
Paul committed
32
    operation(PrivateDetailTypeErasedT value)
Paul's avatar
Paul committed
33
34
35
36
        : private_detail_te_handle_mem_var(
              std::make_shared<private_detail_te_handle_type<
                  typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
                  std::forward<PrivateDetailTypeErasedT>(value)))
Paul's avatar
Paul committed
37
38
39
40
    {
    }

    // Assignment
Paul's avatar
Paul committed
41
    template <typename PrivateDetailTypeErasedT>
Paul's avatar
Paul committed
42
    operation& operator=(PrivateDetailTypeErasedT value)
Paul's avatar
Paul committed
43
    {
Paul's avatar
Paul committed
44
45
46
47
48
        if(private_detail_te_handle_mem_var.unique())
            *private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
        else if(!private_detail_te_handle_mem_var)
            private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
                std::forward<PrivateDetailTypeErasedT>(value));
Paul's avatar
Paul committed
49
50
51
        return *this;
    }

Paul's avatar
Paul committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    // Cast
    template <typename PrivateDetailTypeErasedT>
    PrivateDetailTypeErasedT* any_cast()
    {
        return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

    template <typename PrivateDetailTypeErasedT>
    const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
    {
        return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
                   ? std::addressof(static_cast<const private_detail_te_handle_type<
                                        typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
                                        private_detail_te_get_handle())
                                        .private_detail_te_value)
                   : nullptr;
    }

Paul's avatar
Paul committed
75
76
    std::string name() const
    {
Paul's avatar
Paul committed
77
78
        assert(private_detail_te_handle_mem_var);
        return private_detail_te_get_handle().name();
Paul's avatar
Paul committed
79
80
81
82
    }

    shape compute_shape(std::vector<shape> input) const
    {
Paul's avatar
Paul committed
83
84
        assert(private_detail_te_handle_mem_var);
        return private_detail_te_get_handle().compute_shape(std::move(input));
Paul's avatar
Paul committed
85
86
87
88
    }

    argument compute(std::vector<argument> input) const
    {
Paul's avatar
Paul committed
89
90
        assert(private_detail_te_handle_mem_var);
        return private_detail_te_get_handle().compute(std::move(input));
Paul's avatar
Paul committed
91
92
93
    }

    private:
Paul's avatar
Paul committed
94
    struct private_detail_te_handle_base_type
Paul's avatar
Paul committed
95
    {
Paul's avatar
Paul committed
96
97
        virtual ~private_detail_te_handle_base_type() {}
        virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
Paul's avatar
Paul committed
98
        virtual const std::type_info& type() const                                = 0;
Paul's avatar
Paul committed
99
100
101
102
103
104

        virtual std::string name() const                            = 0;
        virtual shape compute_shape(std::vector<shape> input) const = 0;
        virtual argument compute(std::vector<argument> input) const = 0;
    };

Paul's avatar
Paul committed
105
106
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type : private_detail_te_handle_base_type
Paul's avatar
Paul committed
107
    {
Paul's avatar
Paul committed
108
109
110
111
112
113
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
                nullptr)
            : private_detail_te_value(value)
Paul's avatar
Paul committed
114
115
116
        {
        }

Paul's avatar
Paul committed
117
118
119
120
121
122
        template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
        private_detail_te_handle_type(
            PrivateDetailTypeErasedT value,
            typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
                                    int>::type* = nullptr) noexcept
            : private_detail_te_value(std::move(value))
Paul's avatar
Paul committed
123
124
125
        {
        }

Paul's avatar
Paul committed
126
        std::shared_ptr<private_detail_te_handle_base_type> clone() const override
Paul's avatar
Paul committed
127
        {
Paul's avatar
Paul committed
128
            return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
Paul's avatar
Paul committed
129
130
        }

Paul's avatar
Paul committed
131
        const std::type_info& type() const override { return typeid(private_detail_te_value); }
Paul's avatar
Paul committed
132

Paul's avatar
Paul committed
133
        std::string name() const override { return private_detail_te_value.name(); }
Paul's avatar
Paul committed
134

Paul's avatar
Paul committed
135
        shape compute_shape(std::vector<shape> input) const override
Paul's avatar
Paul committed
136
        {
Paul's avatar
Paul committed
137
            return private_detail_te_value.compute_shape(std::move(input));
Paul's avatar
Paul committed
138
139
        }

Paul's avatar
Paul committed
140
        argument compute(std::vector<argument> input) const override
Paul's avatar
Paul committed
141
        {
Paul's avatar
Paul committed
142
            return private_detail_te_value.compute(std::move(input));
Paul's avatar
Paul committed
143
144
        }

Paul's avatar
Paul committed
145
        PrivateDetailTypeErasedT private_detail_te_value;
Paul's avatar
Paul committed
146
147
    };

Paul's avatar
Paul committed
148
149
150
    template <typename PrivateDetailTypeErasedT>
    struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
        : private_detail_te_handle_type<PrivateDetailTypeErasedT&>
Paul's avatar
Paul committed
151
    {
Paul's avatar
Paul committed
152
153
        private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
            : private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
Paul's avatar
Paul committed
154
155
156
157
        {
        }
    };

Paul's avatar
Paul committed
158
159
160
161
    const private_detail_te_handle_base_type& private_detail_te_get_handle() const
    {
        return *private_detail_te_handle_mem_var;
    }
Paul's avatar
Paul committed
162

Paul's avatar
Paul committed
163
    private_detail_te_handle_base_type& private_detail_te_get_handle()
Paul's avatar
Paul committed
164
    {
Paul's avatar
Paul committed
165
166
167
        if(!private_detail_te_handle_mem_var.unique())
            private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
        return *private_detail_te_handle_mem_var;
Paul's avatar
Paul committed
168
169
    }

Paul's avatar
Paul committed
170
    std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
Paul's avatar
Paul committed
171
172
};

Paul's avatar
Paul committed
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
/// Pointer form of any_cast, following the `std::any_cast(const std::any*)`
/// convention: a null `x` yields nullptr instead of being dereferenced.
/// Otherwise it forwards to `operation::any_cast<ValueType>()`.
template <typename ValueType>
inline const ValueType* any_cast(const operation* x)
{
    if(x == nullptr)
        return nullptr;
    return x->any_cast<ValueType>();
}

/// Mutable pointer form of any_cast, following the `std::any_cast(std::any*)`
/// convention: a null `x` yields nullptr instead of being dereferenced.
/// Otherwise it forwards to `operation::any_cast<ValueType>()`.
template <typename ValueType>
inline ValueType* any_cast(operation* x)
{
    if(x == nullptr)
        return nullptr;
    return x->any_cast<ValueType>();
}

/// Reference form of any_cast: returns the held value as `ValueType&`.
/// Throws std::bad_cast when the member any_cast yields nullptr.
template <typename ValueType>
inline ValueType& any_cast(operation& x)
{
    using target_type        = typename std::remove_reference<ValueType>::type;
    target_type* const found = x.any_cast<target_type>();
    if(found != nullptr)
        return *found;
    throw std::bad_cast();
}

/// Const reference form of any_cast: returns the held value as `const ValueType&`.
/// Throws std::bad_cast when the member any_cast yields nullptr.
template <typename ValueType>
inline const ValueType& any_cast(const operation& x)
{
    using target_type = typename std::remove_reference<ValueType>::type;
    const auto* found = x.any_cast<target_type>();
    if(found != nullptr)
        return *found;
    throw std::bad_cast();
}

Paul's avatar
Paul committed
205
} // namespace rtg
Paul's avatar
Paul committed
206
207

#endif