shape.hpp 3.43 KB
Newer Older
Paul's avatar
Paul committed
1
#ifndef RTG_GUARD_RTGLIB_SHAPE_HPP
Paul's avatar
Paul committed
2
#define RTG_GUARD_RTGLIB_SHAPE_HPP
Paul's avatar
Paul committed
3
4
5

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <ostream>
#include <string>
#include <type_traits>
#include <vector>

#include <rtg/errors.hpp>

Paul's avatar
Paul committed
10
11
12
13
namespace rtg {

struct shape
{
Paul's avatar
Paul committed
14
15

// Add new types here
Paul's avatar
Paul committed
16
// clang-format off
Paul's avatar
Paul committed
17
18
#define RTG_SHAPE_VISIT_TYPES(m) \
    m(float_type, float) \
Paul's avatar
Paul committed
19
20
21
22
23
24
25
26
    m(double_type, double) \
    m(uint8_type, uint8_t) \
    m(int8_type, int8_t) \
    m(uint16_type, uint16_t) \
    m(int16_type, int16_t) \
    m(int32_type, int32_t) \
    m(int64_type, int64_t) \
    m(uint32_type, uint32_t) \
Paul's avatar
Paul committed
27
    m(uint64_type, uint64_t)
Paul's avatar
Paul committed
28
// clang-format on
Paul's avatar
Paul committed
29

Paul's avatar
Paul committed
30
#define RTG_SHAPE_ENUM_TYPES(x, t) x,
Paul's avatar
Paul committed
31
32
    enum type_t
    {
Paul's avatar
Paul committed
33
        any_type,
Paul's avatar
Paul committed
34
        RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
Paul's avatar
Paul committed
35
    };
Paul's avatar
Paul committed
36
37
#undef RTG_SHAPE_ENUM_TYPES

Paul's avatar
Paul committed
38
    template <class T, class = void>
Paul's avatar
Paul committed
39
    struct get_type : std::integral_constant<type_t, any_type>
Paul's avatar
Paul committed
40
41
    {
    };
Paul's avatar
Paul committed
42
43
#define RTG_SHAPE_GET_TYPE(x, t)                              \
    template <class T>                                        \
Paul's avatar
Paul committed
44
    struct get_type<t, T> : std::integral_constant<type_t, x> \
Paul's avatar
Paul committed
45
46
    {                                                         \
    };
Paul's avatar
Paul committed
47
48
    RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE)
#undef RTG_SHAPE_GET_TYPE
Paul's avatar
Paul committed
49
50
51
52
53
54
55

    shape();
    shape(type_t t);
    shape(type_t t, std::vector<std::size_t> l);
    shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);

    type_t type() const;
Paul's avatar
Paul committed
56
57
    const std::vector<std::size_t>& lens() const;
    const std::vector<std::size_t>& strides() const;
Paul's avatar
Paul committed
58
59
60
    std::size_t elements() const;
    std::size_t bytes() const;

Paul's avatar
Paul committed
61
62
63
    std::size_t index(std::initializer_list<std::size_t> l) const;
    std::size_t index(const std::vector<std::size_t>& l) const;

Paul's avatar
Paul committed
64
65
66
67
68
    // Map element index to space index
    std::size_t index(std::size_t i) const;

    bool packed() const;

Paul's avatar
Paul committed
69
70
    friend bool operator==(const shape& x, const shape& y);
    friend bool operator!=(const shape& x, const shape& y);
Paul's avatar
Paul committed
71
    friend std::ostream& operator<<(std::ostream& os, const shape& x);
Paul's avatar
Paul committed
72

Paul's avatar
Paul committed
73
    template <class T>
Paul's avatar
Paul committed
74
75
76
77
    struct as
    {
        using type = T;

Paul's avatar
Paul committed
78
        template <class U>
Paul's avatar
Paul committed
79
80
81
82
83
        T operator()(U u) const
        {
            return T(u);
        }

Paul's avatar
Paul committed
84
        template <class U>
Paul's avatar
Paul committed
85
86
87
88
89
        T* operator()(U* u) const
        {
            return static_cast<T*>(u);
        }

Paul's avatar
Paul committed
90
        template <class U>
Paul's avatar
Paul committed
91
92
93
94
95
        const T* operator()(const U* u) const
        {
            return static_cast<T*>(u);
        }

Paul's avatar
Paul committed
96
        T operator()() const { return {}; }
Paul's avatar
Paul committed
97

Paul's avatar
Paul committed
98
        std::size_t size(std::size_t n = 1) const { return sizeof(T) * n; }
Paul's avatar
Paul committed
99

Paul's avatar
Paul committed
100
101
        template <class U>
        T* from(U* buffer, std::size_t n = 0) const
Paul's avatar
Paul committed
102
        {
Paul's avatar
Paul committed
103
            return reinterpret_cast<T*>(buffer) + n;
Paul's avatar
Paul committed
104
        }
Paul's avatar
Paul committed
105

Paul's avatar
Paul committed
106
107
        template <class U>
        const T* from(const U* buffer, std::size_t n = 0) const
Paul's avatar
Paul committed
108
        {
Paul's avatar
Paul committed
109
            return reinterpret_cast<const T*>(buffer) + n;
Paul's avatar
Paul committed
110
        }
Paul's avatar
Paul committed
111
112
    };

Paul's avatar
Paul committed
113
    template <class Visitor>
Paul's avatar
Paul committed
114
115
    void visit_type(Visitor v) const
    {
Paul's avatar
Paul committed
116
        switch(this->m_type)
Paul's avatar
Paul committed
117
        {
Paul's avatar
Paul committed
118
        case any_type: RTG_THROW("Cannot visit the any_type");
Paul's avatar
Paul committed
119
#define RTG_SHAPE_VISITOR_CASE(x, t) \
Paul's avatar
Paul committed
120
    case x: v(as<t>()); return;
Paul's avatar
Paul committed
121
122
            RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
#undef RTG_SHAPE_VISITOR_CASE
Paul's avatar
Paul committed
123
        }
Paul's avatar
Paul committed
124
        RTG_THROW("Unknown type");
Paul's avatar
Paul committed
125
    }
Paul's avatar
Paul committed
126
127

    private:
Paul's avatar
Paul committed
128
129
130
131
    type_t m_type;
    std::vector<std::size_t> m_lens;
    std::vector<std::size_t> m_strides;
    bool m_packed;
Paul's avatar
Paul committed
132
133
134

    void calculate_strides();
    std::size_t element_space() const;
Paul's avatar
Paul committed
135
    std::string type_string() const;
Paul's avatar
Paul committed
136
137
};

Paul's avatar
Paul committed
138
} // namespace rtg
Paul's avatar
Paul committed
139
140

#endif