// shape.hpp
#ifndef GUARD_RTGLIB_SHAPE_HPP
#define GUARD_RTGLIB_SHAPE_HPP

#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <ostream>
#include <string>
#include <type_traits>
#include <vector>

namespace rtg {

/// Describes a buffer layout: an element type tag (type_t) plus per-dimension
/// lengths (lens) and strides. The non-template member functions are only
/// declared here; their definitions live outside this header.
struct shape
{
// X-macro listing every supported element type as m(enumerator, c++-type).
// Add new types here, one entry per line. It is intentionally not #undef'd,
// presumably so the out-of-line implementation can reuse it (verify there).
#define RTG_SHAPE_VISIT_TYPES(m) \
    m(float_type, float)         \
    m(int_type, int)

#define RTG_SHAPE_ENUM_TYPES(x, t) x,
    /// Element type tag: one enumerator per entry of RTG_SHAPE_VISIT_TYPES.
    enum type_t
    {
        RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_ENUM_TYPES)
    };
#undef RTG_SHAPE_ENUM_TYPES

    /// Maps a C++ element type to its type_t tag, e.g.
    /// get_type<float>::value == float_type. Only the types listed in
    /// RTG_SHAPE_VISIT_TYPES are defined; any other type is incomplete.
    // The otherwise unused second parameter turns each mapping into a
    // *partial* specialization; presumably because explicit (full)
    // specializations inside a class body were not accepted before C++17.
    template<class T, class = void>
    struct get_type;
#define RTG_SHAPE_GET_TYPE(x, t)                              \
    template<class T>                                         \
    struct get_type<t, T> : std::integral_constant<type_t, x> \
    {};
    RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_GET_TYPE)
#undef RTG_SHAPE_GET_TYPE

    // Constructors. Parameter `l` is presumably the lens and `s` the strides
    // (see the out-of-line definitions to confirm).
    shape();
    shape(type_t t);
    shape(type_t t, std::vector<std::size_t> l);
    shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);

    /// Element type tag.
    type_t type() const;
    /// Per-dimension lengths.
    const std::vector<std::size_t>& lens() const;
    /// Per-dimension strides.
    const std::vector<std::size_t>& strides() const;
    std::size_t elements() const;
    std::size_t bytes() const;

    // Presumably map a multi-dimensional index to an offset; defined out of line.
    std::size_t index(std::initializer_list<std::size_t> l) const;
    std::size_t index(const std::vector<std::size_t>& l) const;

    // Map element index to space index
    std::size_t index(std::size_t i) const;

    bool packed() const;

    friend bool operator==(const shape& x, const shape& y);
    friend bool operator!=(const shape& x, const shape& y);
    friend std::ostream& operator<<(std::ostream& os, const shape& x);

    /// Stateless helper passed to visit_type's visitor. It exposes the C++
    /// element type as `type` and converts values, pointers and buffers to it.
    template<class T>
    struct as
    {
        using type = T;

        /// Converts a value to T.
        template<class U>
        T operator()(U u) const
        {
            return static_cast<T>(u);
        }

        /// Converts a pointer (e.g. void*) to T*.
        template<class U>
        T* operator()(U* u) const
        {
            return static_cast<T*>(u);
        }

        /// Converts a const pointer (e.g. const void*) to const T*.
        // static_cast cannot cast away const, so the target must be const T*;
        // casting to T* here would not compile once instantiated.
        template<class U>
        const T* operator()(const U* u) const
        {
            return static_cast<const T*>(u);
        }

        /// Returns a value-initialized T.
        T operator()() const
        {
            return {};
        }

        /// Size in bytes of n elements of T.
        std::size_t size(std::size_t n = 1) const
        {
            return sizeof(T) * n;
        }

        /// Reinterprets `buffer` as an array of T and returns a pointer to element n.
        template<class U>
        T* from(U* buffer, std::size_t n = 0) const
        {
            return reinterpret_cast<T*>(buffer) + n;
        }

        /// Const overload of from().
        template<class U>
        const T* from(const U* buffer, std::size_t n = 0) const
        {
            return reinterpret_cast<const T*>(buffer) + n;
        }
    };

    /// Calls v(as<T>()) exactly once, where T is the C++ type matching type_.
    template<class Visitor>
    void visit_type(Visitor v) const
    {
        switch(this->type_)
        {
#define RTG_SHAPE_VISITOR_CASE(x, t) \
        case x:                      \
            v(as<t>());              \
            return;
            RTG_SHAPE_VISIT_TYPES(RTG_SHAPE_VISITOR_CASE)
#undef RTG_SHAPE_VISITOR_CASE
        }
        // Every valid type_t returns from the switch above; reaching this
        // point means type_ holds a value outside the enum.
        assert(false && "shape::visit_type: unknown type");
    }

private:
    // Default member initializers keep these from ever being indeterminate;
    // the out-of-line constructors may still set them explicitly.
    type_t type_{};
    std::vector<std::size_t> lens_;
    std::vector<std::size_t> strides_;
    bool packed_ = false;

    void calculate_strides();
    std::size_t element_space() const;
    std::string type_string() const;
};

}

#endif