raw_data.hpp 4.8 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4

#ifndef RTG_GUARD_RAW_DATA_HPP
#define RTG_GUARD_RAW_DATA_HPP

Paul's avatar
Paul committed
5
6
#include <rtg/tensor_view.hpp>

Paul's avatar
Paul committed
7
8
namespace rtg {

9
10
11
12
13
// Expands to an unnamed, defaulted template type parameter that removes the enclosing template
// from overload resolution (SFINAE) unless the compile-time condition is true. Example:
//     template <class T, RTG_REQUIRES(std::is_integral<T>{})> void f(T);
#define RTG_REQUIRES(...) class=typename std::enable_if<(__VA_ARGS__)>::type

// Non-template tag base shared by every raw_data<Derived> instantiation, so generic code can
// detect raw data objects with std::is_base_of<raw_data_base, T>.
struct raw_data_base
{
};

/**
 * @brief Provides a base class for common operations on a raw buffer
 *
 * For classes that handle a raw buffer of data, this provides common operations such as equality,
 * printing, and visitors. To use this class the derived class needs to provide a `data()` method to
 * retrieve a raw pointer to the data, and a `get_shape()` method that provides the shape of the
 * data. The object returned by `get_shape()` is used through its `visit_type`, `index` and
 * `elements` members.
 *
 * @tparam Derived The class deriving from this one (CRTP)
 */
template <class Derived>
struct raw_data : raw_data_base
{
    /**
     * @brief Writes the data to a stream
     *
     * Streams the view produced by `visit` into `os`.
     */
    template <class Stream>
    friend Stream& operator<<(Stream& os, const Derived& d)
    {
        d.visit([&](auto x) { os << x; });
        return os;
    }

    /**
     * @brief Visits a single data element at a certain index.
     *
     * The element is located at `shape.index(n)` elements past the start of the buffer,
     * reinterpreted according to the element type of the shape. No bounds check is done here.
     *
     * @param v A function which will be called with the element (as the shape's element type)
     * @param n The index to read from
     */
    template <class Visitor>
    void visit_at(Visitor v, std::size_t n = 0) const
    {
        auto&& s      = static_cast<const Derived&>(*this).get_shape();
        auto&& buffer = static_cast<const Derived&>(*this).data();
        s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); });
    }

    /**
     * @brief Visits the data
     *
     *  This will call the visitor function with a `tensor_view<T>` based on the shape of the data.
     *
     * @param v A function to be called with `tensor_view<T>`
     */
    template <class Visitor>
    void visit(Visitor v) const
    {
        auto&& s      = static_cast<const Derived&>(*this).get_shape();
        auto&& buffer = static_cast<const Derived&>(*this).data();
        s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
    }

    /// Returns true when the shape reports exactly one element.
    bool single() const
    {
        auto&& s = static_cast<const Derived&>(*this).get_shape();
        return s.elements() == 1;
    }

    /**
     * @brief Retrieves a single element of data
     *
     * @param n The index to retrieve the data from
     * @tparam T The type of data to be retrieved
     * @return The element converted to `T`
     */
    template <class T>
    T at(std::size_t n = 0) const
    {
        // Value-initialize so `result` is never indeterminate, even for scalar T.
        T result{};
        this->visit_at([&](auto x) { result = x; }, n);
        return result;
    }

    /**
     * @brief Proxy returned by `get()` that converts implicitly to the requested type.
     *
     * Holds a pointer to the object it came from, so it must not outlive that object. The
     * conversions are `const` so a `const` proxy can be converted as well.
     */
    struct auto_cast
    {
        const Derived* self;

        /// Converts to the first element of the data as `T` (see `at`).
        template <class T>
        operator T() const
        {
            return self->template at<T>();
        }

        /// Reinterprets the raw buffer as a `T*`.
        template <class T>
        operator T*() const
        {
            // TODO: Check that T matches the element type of the shape
            return reinterpret_cast<T*>(self->data());
        }
    };

    /// Returns a proxy that converts to either the first element (`T`) or the buffer (`T*`).
    auto_cast get() const { return {static_cast<const Derived*>(this)}; }
};

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/**
 * @brief Compares two raw data objects for equality
 *
 * Two empty objects compare equal. Otherwise both must be non-empty, have equal shapes, and have
 * equal elements (compared through views of the shape's element type).
 */
template <class T,
          class U,
          RTG_REQUIRES(std::is_base_of<raw_data_base, T>{}),
          RTG_REQUIRES(std::is_base_of<raw_data_base, U>{})>
bool operator==(const T& x, const U& y)
{
    const bool x_empty = x.empty();
    const bool y_empty = y.empty();
    // If exactly one side is empty the objects differ. Returning here also avoids building a view
    // over the empty side's buffer, which may not be valid to read.
    if(x_empty or y_empty)
        return x_empty and y_empty;

    auto&& xshape = x.get_shape();
    auto&& yshape = y.get_shape();
    if(not(xshape == yshape))
        return false;

    auto&& xbuffer = x.data();
    auto&& ybuffer = y.data();
    bool result    = false;
    // TODO: Don't use tensor view for single values
    // Assumes shape equality implies the same element type, so xshape's type serves both views.
    xshape.visit_type([&](auto as) {
        auto xview = make_view(xshape, as.from(xbuffer));
        auto yview = make_view(yshape, as.from(ybuffer));
        result     = xview == yview;
    });
    return result;
}

/**
 * @brief Inequality for raw data objects
 *
 * Defined as the negation of `operator==` for the same pair of arguments.
 */
template <class T,
          class U,
          RTG_REQUIRES(std::is_base_of<raw_data_base, T>{}),
          RTG_REQUIRES(std::is_base_of<raw_data_base, U>{})>
bool operator!=(const T& x, const U& y)
{
    if(x == y)
        return false;
    return true;
}

Paul's avatar
Paul committed
127
namespace detail {
Paul's avatar
Paul committed
128
template <class V, class... Ts>
Paul's avatar
Paul committed
129
130
void visit_all_impl(const shape& s, V&& v, Ts&&... xs)
{
Paul's avatar
Paul committed
131
    s.visit_type([&](auto as) { v(make_view(xs.get_shape(), as.from(xs.data()))...); });
Paul's avatar
Paul committed
132
}
Paul's avatar
Paul committed
133
} // namespace detail
Paul's avatar
Paul committed
134
135
136

/**
 * @brief Visits every object together
Paul's avatar
Paul committed
137
138
139
140
141
 * @details This will visit every object, but assumes each object is the same type. This can reduce
 * the deeply nested visit calls. This will return a function that will take the visitor callback.
 * So it will be called with `visit_all(xs...)([](auto... ys) {})` where `xs...` and `ys...` are the
 * same number of parameters.
 *
Paul's avatar
Paul committed
142
143
144
145
 * @param x A raw data object
 * @param xs Many raw data objects
 * @return A function to be called with the visitor
 */
Paul's avatar
Paul committed
146
template <class T, class... Ts>
Paul's avatar
Paul committed
147
148
auto visit_all(T&& x, Ts&&... xs)
{
Paul's avatar
Paul committed
149
    auto&& s                                   = x.get_shape();
Paul's avatar
Paul committed
150
    std::initializer_list<shape::type_t> types = {xs.get_shape().type()...};
Paul's avatar
Paul committed
151
    if(!std::all_of(types.begin(), types.end(), [&](shape::type_t t) { return t == s.type(); }))
Paul's avatar
Paul committed
152
153
        RTG_THROW("Types must be the same");
    return [&](auto v) {
Paul's avatar
Paul committed
154
155
        // Workaround for https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70100
        detail::visit_all_impl(s, v, x, xs...);
Paul's avatar
Paul committed
156
157
158
    };
}

Paul's avatar
Paul committed
159
160
161
} // namespace rtg

#endif