raw_data.hpp 5.58 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4

#ifndef RTG_GUARD_RAW_DATA_HPP
#define RTG_GUARD_RAW_DATA_HPP

Paul's avatar
Paul committed
5
6
#include <rtg/tensor_view.hpp>

Paul's avatar
Paul committed
7
8
namespace rtg {

Paul's avatar
Paul committed
9
// SFINAE helper for template parameter lists: expands to an unnamed, defaulted type parameter
// `class = typename std::enable_if<(cond)>::type`, so the enclosing template only participates
// in overload resolution when the compile-time condition passed as the macro argument is true.
#define RTG_REQUIRES(...) class = typename std::enable_if<(__VA_ARGS__)>::type
10
11

/// Empty, non-template tag type. Every `raw_data<Derived>` inherits from it, which lets generic
/// code (such as the comparison operators in this header) recognise raw-data types with
/// `std::is_base_of<raw_data_base, T>` regardless of the concrete `Derived`.
struct raw_data_base {};
14

Paul's avatar
Paul committed
15
16
/**
 * @brief Provides a base class for common operations with raw buffer
Paul's avatar
Paul committed
17
18
19
20
21
 *
 * For classes that handle a raw buffer of data, this will provide common operations such as equals,
 * printing, and visitors. To use this class the derived class needs to provide a `data()` method to
 * retrieve a raw pointer to the data, and `get_shape` method that provides the shape of the data.
 *
Paul's avatar
Paul committed
22
 */
Paul's avatar
Paul committed
23
template <class Derived>
24
struct raw_data : raw_data_base
Paul's avatar
Paul committed
25
{
Paul's avatar
Paul committed
26
    template <class Stream>
Paul's avatar
Paul committed
27
28
    friend Stream& operator<<(Stream& os, const Derived& d)
    {
Paul's avatar
Paul committed
29
        d.visit([&](auto x) { os << x; });
Paul's avatar
Paul committed
30
31
        return os;
    }
Paul's avatar
Paul committed
32

Paul's avatar
Paul committed
33
34
    /**
     * @brief Visits a single data element at a certain index.
Paul's avatar
Paul committed
35
     *
Paul's avatar
Paul committed
36
37
38
     * @param v A function which will be called with the type of data
     * @param n The index to read from
     */
Paul's avatar
Paul committed
39
40
    template <class Visitor>
    void visit_at(Visitor v, std::size_t n = 0) const
Paul's avatar
Paul committed
41
    {
Paul's avatar
Paul committed
42
43
44
        auto&& s      = static_cast<const Derived&>(*this).get_shape();
        auto&& buffer = static_cast<const Derived&>(*this).data();
        s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); });
Paul's avatar
Paul committed
45
46
    }

Paul's avatar
Paul committed
47
48
    /**
     * @brief Visits the data
Paul's avatar
Paul committed
49
     *
Paul's avatar
Paul committed
50
     *  This will call the visitor function with a `tensor_view<T>` based on the shape of the data.
Paul's avatar
Paul committed
51
     *
Paul's avatar
Paul committed
52
53
     * @param v A function to be called with `tensor_view<T>`
     */
Paul's avatar
Paul committed
54
    template <class Visitor>
Paul's avatar
Paul committed
55
56
    void visit(Visitor v) const
    {
Paul's avatar
Paul committed
57
58
59
        auto&& s      = static_cast<const Derived&>(*this).get_shape();
        auto&& buffer = static_cast<const Derived&>(*this).data();
        s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
Paul's avatar
Paul committed
60
61
    }

62
    /// Returns true if the raw data is only one element
Paul's avatar
Paul committed
63
64
    bool single() const
    {
Paul's avatar
Paul committed
65
        auto&& s = static_cast<const Derived&>(*this).get_shape();
Paul's avatar
Paul committed
66
        return s.elements() == 1;
Paul's avatar
Paul committed
67
68
    }

Paul's avatar
Paul committed
69
70
    /**
     * @brief Retrieves a single element of data
Paul's avatar
Paul committed
71
     *
Paul's avatar
Paul committed
72
73
74
75
     * @param n The index to retrieve the data from
     * @tparam T The type of data to be retrieved
     * @return The element as `T`
     */
Paul's avatar
Paul committed
76
77
    template <class T>
    T at(std::size_t n = 0) const
Paul's avatar
Paul committed
78
79
    {
        T result;
Paul's avatar
Paul committed
80
        this->visit_at([&](auto x) { result = x; }, n);
Paul's avatar
Paul committed
81
82
        return result;
    }
Paul's avatar
Paul committed
83
84
85

    struct auto_cast
    {
Paul's avatar
Paul committed
86
87
        const Derived* self;
        template <class T>
Paul's avatar
Paul committed
88
89
        operator T()
        {
90
            assert(self->single());
Paul's avatar
Paul committed
91
92
            return self->template at<T>();
        }
Paul's avatar
Paul committed
93
        template <class T>
Paul's avatar
Paul committed
94
95
        operator T*()
        {
96
            using type = std::remove_cv_t<T>;
Paul's avatar
Paul committed
97
98
99
            assert((std::is_void<T>{} or std::is_same<char, type>{} or
                    std::is_same<unsigned char, type>{} or
                    self->get_shape().type() == rtg::shape::get_type<T>{}));
100
            return reinterpret_cast<type*>(self->data());
Paul's avatar
Paul committed
101
102
103
        }
    };

104
105
106
107
    /// Implicit conversion of raw data pointer
    auto_cast implicit() const { return {static_cast<const Derived*>(this)}; }

    /// Get a tensor_view to the data
Paul's avatar
Paul committed
108
    template <class T>
109
110
    tensor_view<T> get() const
    {
Paul's avatar
Paul committed
111
        auto&& s      = static_cast<const Derived&>(*this).get_shape();
112
113
114
115
116
        auto&& buffer = static_cast<const Derived&>(*this).data();
        if(s.type() != rtg::shape::get_type<T>{})
            RTG_THROW("Incorrect data type for raw data");
        return make_view(s, reinterpret_cast<T*>(buffer));
    }
Paul's avatar
Paul committed
117
118
};

Paul's avatar
Paul committed
119
120
template <class T,
          class U,
Paul's avatar
Paul committed
121
          RTG_REQUIRES(std::is_base_of<raw_data_base, T>{} && std::is_base_of<raw_data_base, U>{})>
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
bool operator==(const T& x, const U& y)
{
    auto&& xshape = x.get_shape();
    auto&& yshape = y.get_shape();
    bool result   = x.empty() && y.empty();
    if(not result && xshape == yshape)
    {
        auto&& xbuffer = x.data();
        auto&& ybuffer = y.data();
        // TODO: Dont use tensor view for single values
        xshape.visit_type([&](auto as) {
            auto xview = make_view(xshape, as.from(xbuffer));
            auto yview = make_view(yshape, as.from(ybuffer));
            result     = xview == yview;
        });
    }
    return result;
}

Paul's avatar
Paul committed
141
142
template <class T,
          class U,
Paul's avatar
Paul committed
143
          RTG_REQUIRES(std::is_base_of<raw_data_base, T>{} && std::is_base_of<raw_data_base, U>{})>
144
145
146
147
148
bool operator!=(const T& x, const U& y)
{
    return !(x == y);
}

Paul's avatar
Paul committed
149
namespace detail {
Paul's avatar
Paul committed
150
template <class V, class... Ts>
Paul's avatar
Paul committed
151
152
void visit_all_impl(const shape& s, V&& v, Ts&&... xs)
{
Paul's avatar
Paul committed
153
    s.visit_type([&](auto as) { v(make_view(xs.get_shape(), as.from(xs.data()))...); });
Paul's avatar
Paul committed
154
}
Paul's avatar
Paul committed
155
} // namespace detail
Paul's avatar
Paul committed
156
157
158

/**
 * @brief Visits every object together
Paul's avatar
Paul committed
159
160
161
162
163
 * @details This will visit every object, but assumes each object is the same type. This can reduce
 * the deeply nested visit calls. This will return a function that will take the visitor callback.
 * So it will be called with `visit_all(xs...)([](auto... ys) {})` where `xs...` and `ys...` are the
 * same number of parameters.
 *
Paul's avatar
Paul committed
164
165
166
167
 * @param x A raw data object
 * @param xs Many raw data objects
 * @return A function to be called with the visitor
 */
Paul's avatar
Paul committed
168
template <class T, class... Ts>
Paul's avatar
Paul committed
169
170
auto visit_all(T&& x, Ts&&... xs)
{
Paul's avatar
Paul committed
171
    auto&& s                                   = x.get_shape();
Paul's avatar
Paul committed
172
    std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
Paul's avatar
Paul committed
173
    if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
Paul's avatar
Paul committed
174
175
        RTG_THROW("Types must be the same");
    return [&](auto v) {
Paul's avatar
Paul committed
176
177
        // Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
        detail::visit_all_impl(s, v, x, xs...);
Paul's avatar
Paul committed
178
179
180
    };
}

Paul's avatar
Paul committed
181
182
183
} // namespace rtg

#endif