"src/vscode:/vscode.git/clone" did not exist on "d7a283004e1db3c25cb9eed0d640d427ff771dc6"
shape.hpp 4.52 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
Paul's avatar
Paul committed
3
4
5

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <numeric>
#include <ostream>
#include <string>
#include <type_traits>
#include <vector>

#include <migraph/errors.hpp>
#include <migraph/half.hpp>
Paul's avatar
Paul committed
12

Paul's avatar
Paul committed
13
namespace migraph {
Paul's avatar
Paul committed
14

Paul's avatar
Paul committed
15
16
/// Forward declaration of shape's implementation type. It is only declared (left
/// incomplete) in this header; shape holds it through a std::shared_ptr<const shape_impl>.
struct shape_impl;

Paul's avatar
Paul committed
17
18
struct shape
{
Paul's avatar
Paul committed
19
20

// Add new types here
Paul's avatar
Paul committed
21
// clang-format off
Paul's avatar
Paul committed
22
#define MIGRAPH_SHAPE_VISIT_TYPES(m) \
Paul's avatar
Paul committed
23
    m(half_type, half) \
Paul's avatar
Paul committed
24
    m(float_type, float) \
Paul's avatar
Paul committed
25
26
27
28
29
30
31
32
    m(double_type, double) \
    m(uint8_type, uint8_t) \
    m(int8_type, int8_t) \
    m(uint16_type, uint16_t) \
    m(int16_type, int16_t) \
    m(int32_type, int32_t) \
    m(int64_type, int64_t) \
    m(uint32_type, uint32_t) \
Paul's avatar
Paul committed
33
    m(uint64_type, uint64_t)
Paul's avatar
Paul committed
34
// clang-format on
Paul's avatar
Paul committed
35

Paul's avatar
Paul committed
36
#define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x,
Paul's avatar
Paul committed
37
38
    enum type_t
    {
Paul's avatar
Paul committed
39
        MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_ENUM_TYPES)
Paul's avatar
Paul committed
40
    };
Paul's avatar
Paul committed
41
#undef MIGRAPH_SHAPE_ENUM_TYPES
Paul's avatar
Paul committed
42

Paul's avatar
Paul committed
43
    template <class T, class = void>
Paul's avatar
Paul committed
44
    struct get_type;
Paul's avatar
Paul committed
45
#define MIGRAPH_SHAPE_GET_TYPE(x, t)                          \
Paul's avatar
Paul committed
46
    template <class T>                                        \
Paul's avatar
Paul committed
47
    struct get_type<t, T> : std::integral_constant<type_t, x> \
Paul's avatar
Paul committed
48
49
    {                                                         \
    };
Paul's avatar
Paul committed
50
51
    MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE)
#undef MIGRAPH_SHAPE_GET_TYPE
Paul's avatar
Paul committed
52

wsttiger's avatar
wsttiger committed
53
54
55
56
57
    template <class T>
    struct get_type<const T> : get_type<T>
    {
    };

Paul's avatar
Paul committed
58
59
60
61
62
63
    shape();
    shape(type_t t);
    shape(type_t t, std::vector<std::size_t> l);
    shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);

    type_t type() const;
Paul's avatar
Paul committed
64
65
    const std::vector<std::size_t>& lens() const;
    const std::vector<std::size_t>& strides() const;
Paul's avatar
Paul committed
66
67
    std::size_t elements() const;
    std::size_t bytes() const;
Scott Thornton's avatar
Scott Thornton committed
68
    std::size_t type_size() const;
Paul's avatar
Paul committed
69

Paul's avatar
Paul committed
70
    /// Map multiple indices to space index
Paul's avatar
Paul committed
71
    std::size_t index(std::initializer_list<std::size_t> l) const;
Paul's avatar
Paul committed
72
    /// Map multiple indices to space index
Paul's avatar
Paul committed
73
    std::size_t index(const std::vector<std::size_t>& l) const;
Paul's avatar
Paul committed
74

Paul's avatar
Paul committed
75
    /// Map multiple indices from a range of iterator to a space index
Paul's avatar
Paul committed
76
    template <class Iterator>
Paul's avatar
Paul committed
77
78
79
80
81
82
    std::size_t index(Iterator start, Iterator last) const
    {
        assert(std::distance(start, last) <= this->lens().size());
        assert(this->lens().size() == this->strides().size());
        return std::inner_product(start, last, this->strides().begin(), std::size_t{0});
    }
Paul's avatar
Paul committed
83

Paul's avatar
Paul committed
84
    /// Map element index to space index
Paul's avatar
Paul committed
85
86
    std::size_t index(std::size_t i) const;

Paul's avatar
Paul committed
87
    /// Returns true if the shape is packed with no padding
Paul's avatar
Paul committed
88
    bool packed() const;
Paul's avatar
Paul committed
89
90
    /// Returns true is the shape has been transposed. That is the strides are not in descending
    /// order
Paul's avatar
Paul committed
91
    bool transposed() const;
Paul's avatar
Paul committed
92
    /// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
Paul's avatar
Paul committed
93
    bool broadcasted() const;
Paul's avatar
Paul committed
94
95
    /// Returns true if the shape is in its standard format. That is, the shape is both packed and
    /// not transposed.
Paul's avatar
Paul committed
96
    bool standard() const;
Khalique's avatar
Khalique committed
97
98
    /// Returns true if all strides are equal to 0 (scalar tensor)
    bool scalar() const;
Paul's avatar
Paul committed
99

Paul's avatar
Paul committed
100
101
    friend bool operator==(const shape& x, const shape& y);
    friend bool operator!=(const shape& x, const shape& y);
Paul's avatar
Paul committed
102
    friend std::ostream& operator<<(std::ostream& os, const shape& x);
Paul's avatar
Paul committed
103

Paul's avatar
Paul committed
104
    template <class T>
Paul's avatar
Paul committed
105
106
107
108
    struct as
    {
        using type = T;

Paul's avatar
Paul committed
109
        template <class U>
Paul's avatar
Paul committed
110
111
112
113
114
        T operator()(U u) const
        {
            return T(u);
        }

Paul's avatar
Paul committed
115
        template <class U>
Paul's avatar
Paul committed
116
117
118
119
120
        T* operator()(U* u) const
        {
            return static_cast<T*>(u);
        }

Paul's avatar
Paul committed
121
        template <class U>
Paul's avatar
Paul committed
122
123
124
125
126
        const T* operator()(const U* u) const
        {
            return static_cast<T*>(u);
        }

Paul's avatar
Paul committed
127
        T operator()() const { return {}; }
Paul's avatar
Paul committed
128

Paul's avatar
Paul committed
129
        std::size_t size(std::size_t n = 1) const { return sizeof(T) * n; }
Paul's avatar
Paul committed
130

Paul's avatar
Paul committed
131
132
        template <class U>
        T* from(U* buffer, std::size_t n = 0) const
Paul's avatar
Paul committed
133
        {
Paul's avatar
Paul committed
134
            return reinterpret_cast<T*>(buffer) + n;
Paul's avatar
Paul committed
135
        }
Paul's avatar
Paul committed
136

Paul's avatar
Paul committed
137
138
        template <class U>
        const T* from(const U* buffer, std::size_t n = 0) const
Paul's avatar
Paul committed
139
        {
Paul's avatar
Paul committed
140
            return reinterpret_cast<const T*>(buffer) + n;
Paul's avatar
Paul committed
141
        }
Paul's avatar
Paul committed
142
143
    };

Paul's avatar
Paul committed
144
    template <class Visitor>
Paul's avatar
Paul committed
145
146
    void visit_type(Visitor v) const
    {
Paul's avatar
Paul committed
147
        switch(this->type())
Paul's avatar
Paul committed
148
        {
Paul's avatar
Paul committed
149
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
Paul's avatar
Paul committed
150
    case x: v(as<t>()); return;
Paul's avatar
Paul committed
151
152
            MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_VISITOR_CASE)
#undef MIGRAPH_SHAPE_VISITOR_CASE
Paul's avatar
Paul committed
153
        }
Paul's avatar
Paul committed
154
        MIGRAPH_THROW("Unknown type");
Paul's avatar
Paul committed
155
    }
Paul's avatar
Paul committed
156
157

    private:
Paul's avatar
Paul committed
158
    std::shared_ptr<const shape_impl> impl;
Paul's avatar
Paul committed
159

Paul's avatar
Paul committed
160
    std::size_t element_space() const;
Paul's avatar
Paul committed
161
    std::string type_string() const;
Paul's avatar
Paul committed
162
163
};

Paul's avatar
Paul committed
164
} // namespace migraph
Paul's avatar
Paul committed
165
166

#endif