raw_data.hpp 5.86 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4

#ifndef RTG_GUARD_RAW_DATA_HPP
#define RTG_GUARD_RAW_DATA_HPP

Paul's avatar
Paul committed
5
#include <rtg/tensor_view.hpp>
Paul's avatar
Paul committed
6
#include <rtg/requires.hpp>
Paul's avatar
Paul committed
7

Paul's avatar
Paul committed
8
9
namespace rtg {

10
/// Empty, non-template tag base shared by every `raw_data<Derived>` instantiation. Generic code in
/// this header (for example the comparison operators) detects raw data types through
/// `std::is_base_of<raw_data_base, T>`.
struct raw_data_base {};
13

Paul's avatar
Paul committed
14
15
/**
 * @brief Provides a base class for common operations with raw buffer
Paul's avatar
Paul committed
16
17
18
19
20
 *
 * For classes that handle a raw buffer of data, this will provide common operations such as equals,
 * printing, and visitors. To use this class the derived class needs to provide a `data()` method to
 * retrieve a raw pointer to the data, and `get_shape` method that provides the shape of the data.
 *
Paul's avatar
Paul committed
21
 */
Paul's avatar
Paul committed
22
template <class Derived>
23
struct raw_data : raw_data_base
Paul's avatar
Paul committed
24
{
Paul's avatar
Paul committed
25
    template <class Stream>
Paul's avatar
Paul committed
26
27
    friend Stream& operator<<(Stream& os, const Derived& d)
    {
Paul's avatar
Paul committed
28
        d.visit([&](auto x) { os << x; });
Paul's avatar
Paul committed
29
30
        return os;
    }
Paul's avatar
Paul committed
31

Paul's avatar
Paul committed
32
33
    /**
     * @brief Visits a single data element at a certain index.
Paul's avatar
Paul committed
34
     *
Paul's avatar
Paul committed
35
36
37
     * @param v A function which will be called with the type of data
     * @param n The index to read from
     */
Paul's avatar
Paul committed
38
39
    template <class Visitor>
    void visit_at(Visitor v, std::size_t n = 0) const
Paul's avatar
Paul committed
40
    {
Paul's avatar
Paul committed
41
42
43
        auto&& s      = static_cast<const Derived&>(*this).get_shape();
        auto&& buffer = static_cast<const Derived&>(*this).data();
        s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); });
Paul's avatar
Paul committed
44
45
    }

Paul's avatar
Paul committed
46
47
    /**
     * @brief Visits the data
Paul's avatar
Paul committed
48
     *
Paul's avatar
Paul committed
49
     *  This will call the visitor function with a `tensor_view<T>` based on the shape of the data.
Paul's avatar
Paul committed
50
     *
Paul's avatar
Paul committed
51
52
     * @param v A function to be called with `tensor_view<T>`
     */
Paul's avatar
Paul committed
53
    template <class Visitor>
Paul's avatar
Paul committed
54
55
    void visit(Visitor v) const
    {
Paul's avatar
Paul committed
56
57
58
        auto&& s      = static_cast<const Derived&>(*this).get_shape();
        auto&& buffer = static_cast<const Derived&>(*this).data();
        s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
Paul's avatar
Paul committed
59
60
    }

61
    /// Returns true if the raw data is only one element
Paul's avatar
Paul committed
62
63
    bool single() const
    {
Paul's avatar
Paul committed
64
        auto&& s = static_cast<const Derived&>(*this).get_shape();
Paul's avatar
Paul committed
65
        return s.elements() == 1;
Paul's avatar
Paul committed
66
67
    }

Paul's avatar
Paul committed
68
69
    /**
     * @brief Retrieves a single element of data
Paul's avatar
Paul committed
70
     *
Paul's avatar
Paul committed
71
72
73
74
     * @param n The index to retrieve the data from
     * @tparam T The type of data to be retrieved
     * @return The element as `T`
     */
Paul's avatar
Paul committed
75
76
    template <class T>
    T at(std::size_t n = 0) const
Paul's avatar
Paul committed
77
78
    {
        T result;
Paul's avatar
Paul committed
79
        this->visit_at([&](auto x) { result = x; }, n);
Paul's avatar
Paul committed
80
81
        return result;
    }
Paul's avatar
Paul committed
82
83
84

    struct auto_cast
    {
Paul's avatar
Paul committed
85
86
        const Derived* self;
        template <class T>
Paul's avatar
Paul committed
87
88
        operator T()
        {
89
            assert(self->single());
Paul's avatar
Paul committed
90
91
            return self->template at<T>();
        }
Paul's avatar
Paul committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107

        template<class T>
        using is_data_ptr = bool_c<(std::is_void<T>{} or std::is_same<char, std::remove_cv_t<T>>{} or std::is_same<unsigned char, std::remove_cv_t<T>>{})>;

        template<class T>
        using get_data_type = std::conditional_t<is_data_ptr<T>{},
            float,
            T
        >;

        template<class T>
        bool matches() const
        {
            return is_data_ptr<T>{} || self->get_shape().type() == rtg::shape::get_type<get_data_type<T>>{};
        }

Paul's avatar
Paul committed
108
        template <class T>
Paul's avatar
Paul committed
109
110
        operator T*()
        {
111
            using type = std::remove_cv_t<T>;
Paul's avatar
Paul committed
112
            assert(matches<T>());
113
            return reinterpret_cast<type*>(self->data());
Paul's avatar
Paul committed
114
115
116
        }
    };

117
118
119
120
    /// Implicit conversion of raw data pointer
    auto_cast implicit() const { return {static_cast<const Derived*>(this)}; }

    /// Get a tensor_view to the data
Paul's avatar
Paul committed
121
    template <class T>
122
123
    tensor_view<T> get() const
    {
Paul's avatar
Paul committed
124
        auto&& s      = static_cast<const Derived&>(*this).get_shape();
125
126
127
128
129
        auto&& buffer = static_cast<const Derived&>(*this).data();
        if(s.type() != rtg::shape::get_type<T>{})
            RTG_THROW("Incorrect data type for raw data");
        return make_view(s, reinterpret_cast<T*>(buffer));
    }
Paul's avatar
Paul committed
130
131
};

Paul's avatar
Paul committed
132
133
template <class T,
          class U,
Paul's avatar
Paul committed
134
          RTG_REQUIRES(std::is_base_of<raw_data_base, T>{} && std::is_base_of<raw_data_base, U>{})>
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
bool operator==(const T& x, const U& y)
{
    auto&& xshape = x.get_shape();
    auto&& yshape = y.get_shape();
    bool result   = x.empty() && y.empty();
    if(not result && xshape == yshape)
    {
        auto&& xbuffer = x.data();
        auto&& ybuffer = y.data();
        // TODO: Dont use tensor view for single values
        xshape.visit_type([&](auto as) {
            auto xview = make_view(xshape, as.from(xbuffer));
            auto yview = make_view(yshape, as.from(ybuffer));
            result     = xview == yview;
        });
    }
    return result;
}

Paul's avatar
Paul committed
154
155
template <class T,
          class U,
Paul's avatar
Paul committed
156
          RTG_REQUIRES(std::is_base_of<raw_data_base, T>{} && std::is_base_of<raw_data_base, U>{})>
157
158
159
160
161
bool operator!=(const T& x, const U& y)
{
    return !(x == y);
}

Paul's avatar
Paul committed
162
namespace detail {
Paul's avatar
Paul committed
163
template <class V, class... Ts>
Paul's avatar
Paul committed
164
165
void visit_all_impl(const shape& s, V&& v, Ts&&... xs)
{
Paul's avatar
Paul committed
166
    s.visit_type([&](auto as) { v(make_view(xs.get_shape(), as.from(xs.data()))...); });
Paul's avatar
Paul committed
167
}
Paul's avatar
Paul committed
168
} // namespace detail
Paul's avatar
Paul committed
169
170
171

/**
 * @brief Visits every object together
Paul's avatar
Paul committed
172
173
174
175
176
 * @details This will visit every object, but assumes each object is the same type. This can reduce
 * the deeply nested visit calls. This will return a function that will take the visitor callback.
 * So it will be called with `visit_all(xs...)([](auto... ys) {})` where `xs...` and `ys...` are the
 * same number of parameters.
 *
Paul's avatar
Paul committed
177
178
179
180
 * @param x A raw data object
 * @param xs Many raw data objects
 * @return A function to be called with the visitor
 */
Paul's avatar
Paul committed
181
template <class T, class... Ts>
Paul's avatar
Paul committed
182
183
auto visit_all(T&& x, Ts&&... xs)
{
Paul's avatar
Paul committed
184
    auto&& s                                   = x.get_shape();
Paul's avatar
Paul committed
185
    std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
Paul's avatar
Paul committed
186
    if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
Paul's avatar
Paul committed
187
188
        RTG_THROW("Types must be the same");
    return [&](auto v) {
Paul's avatar
Paul committed
189
190
        // Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
        detail::visit_all_impl(s, v, x, xs...);
Paul's avatar
Paul committed
191
192
193
    };
}

Paul's avatar
Paul committed
194
195
196
} // namespace rtg

#endif