"src/vscode:/vscode.git/clone" did not exist on "ffb96ae424252c9e42497b6fb6d47cc70f755b7e"
shape.hpp 6.49 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
Paul's avatar
Paul committed
3
4
5

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <memory>
#include <numeric>
#include <ostream>
#include <string>
#include <type_traits>
#include <vector>

#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
13

Paul's avatar
Paul committed
14
namespace migraphx {
Paul's avatar
Paul committed
15
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
16

17
// Forward declaration; value is only used by reference in the migraphx_to_value /
// migraphx_from_value declarations, so its full definition is not needed here.
struct value;
Paul's avatar
Paul committed
18
19
// Forward declaration of shape's implementation type. shape holds it only through a
// std::shared_ptr<const shape_impl>, so the definition can live out of line.
struct shape_impl;

Paul's avatar
Paul committed
20
21
struct shape
{
Paul's avatar
Paul committed
22
23

// Add new types here
Paul's avatar
Paul committed
24
// clang-format off
Paul's avatar
Paul committed
25
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
26
    m(bool_type, bool) \
Paul's avatar
Paul committed
27
    m(half_type, half) \
Paul's avatar
Paul committed
28
    m(float_type, float) \
Paul's avatar
Paul committed
29
30
31
32
33
34
35
36
    m(double_type, double) \
    m(uint8_type, uint8_t) \
    m(int8_type, int8_t) \
    m(uint16_type, uint16_t) \
    m(int16_type, int16_t) \
    m(int32_type, int32_t) \
    m(int64_type, int64_t) \
    m(uint32_type, uint32_t) \
Paul's avatar
Paul committed
37
    m(uint64_type, uint64_t)
Paul's avatar
Paul committed
38
// clang-format on
Paul's avatar
Paul committed
39

Paul's avatar
Paul committed
40
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
Paul's avatar
Paul committed
41
42
    enum type_t
    {
Paul's avatar
Paul committed
43
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
Paul's avatar
Paul committed
44
    };
Paul's avatar
Paul committed
45
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
Paul's avatar
Paul committed
46

Paul's avatar
Paul committed
47
    template <class T, class = void>
Paul's avatar
Paul committed
48
    struct get_type;
Paul's avatar
Paul committed
49
#define MIGRAPHX_SHAPE_GENERATE_GET_TYPE(x, t)                \
Paul's avatar
Paul committed
50
    template <class T>                                        \
Paul's avatar
Paul committed
51
    struct get_type<t, T> : std::integral_constant<type_t, x> \
Paul's avatar
Paul committed
52
53
    {                                                         \
    };
Paul's avatar
Paul committed
54
55
    MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_GET_TYPE)
#undef MIGRAPHX_SHAPE_GENERATE_GET_TYPE
Paul's avatar
Paul committed
56

wsttiger's avatar
wsttiger committed
57
58
59
60
61
    template <class T>
    struct get_type<const T> : get_type<T>
    {
    };

62
63
    static const std::vector<type_t>& types();

64
65
66
    static std::string name(type_t t);
    static std::string cpp_type(type_t t);

Paul's avatar
Paul committed
67
68
69
70
    shape();
    shape(type_t t);
    shape(type_t t, std::vector<std::size_t> l);
    shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
Paul's avatar
Paul committed
71

Paul's avatar
Paul committed
72
73
74
75
76
77
    template <class Range>
    shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
    {
    }

    template <class Range1, class Range2>
Paul's avatar
Paul committed
78
    shape(type_t t, const Range1& l, const Range2& s)
Paul's avatar
Paul committed
79
80
81
82
83
        : shape(t,
                std::vector<std::size_t>(l.begin(), l.end()),
                std::vector<std::size_t>(s.begin(), s.end()))
    {
    }
Paul's avatar
Paul committed
84

85
86
    static shape
    from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);
Paul's avatar
Paul committed
87
    type_t type() const;
Paul's avatar
Paul committed
88
89
    const std::vector<std::size_t>& lens() const;
    const std::vector<std::size_t>& strides() const;
Paul's avatar
Paul committed
90
91
    std::size_t elements() const;
    std::size_t bytes() const;
Scott Thornton's avatar
Scott Thornton committed
92
    std::size_t type_size() const;
Paul's avatar
Paul committed
93

Paul's avatar
Paul committed
94
    /// Map multiple indices to space index
Paul's avatar
Paul committed
95
    std::size_t index(std::initializer_list<std::size_t> l) const;
Paul's avatar
Paul committed
96
    /// Map multiple indices to space index
Paul's avatar
Paul committed
97
    std::size_t index(const std::vector<std::size_t>& l) const;
Paul's avatar
Paul committed
98

Paul's avatar
Paul committed
99
    /// Map multiple indices from a range of iterator to a space index
Paul's avatar
Paul committed
100
    template <class Iterator>
Paul's avatar
Paul committed
101
102
103
104
    std::size_t index(Iterator start, Iterator last) const
    {
        assert(std::distance(start, last) <= this->lens().size());
        assert(this->lens().size() == this->strides().size());
Paul Fultz II's avatar
Paul Fultz II committed
105
        return std::inner_product(start, last, this->strides().begin(), std::size_t{0}); // NOLINT
Paul's avatar
Paul committed
106
    }
Paul's avatar
Paul committed
107

Paul's avatar
Paul committed
108
    /// Map element index to space index
Paul's avatar
Paul committed
109
110
    std::size_t index(std::size_t i) const;

111
    std::vector<std::size_t> multi(std::size_t i) const;
112
    void multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const;
113

Paul's avatar
Paul committed
114
    /// Returns true if the shape is packed with no padding
Paul's avatar
Paul committed
115
    bool packed() const;
Paul's avatar
Paul committed
116
117
    /// Returns true is the shape has been transposed. That is the strides are not in descending
    /// order
Paul's avatar
Paul committed
118
    bool transposed() const;
Paul's avatar
Paul committed
119
    /// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
Paul's avatar
Paul committed
120
    bool broadcasted() const;
Paul's avatar
Paul committed
121
122
    /// Returns true if the shape is in its standard format. That is, the shape is both packed and
    /// not transposed.
Paul's avatar
Paul committed
123
    bool standard() const;
Khalique's avatar
Khalique committed
124
125
    /// Returns true if all strides are equal to 0 (scalar tensor)
    bool scalar() const;
Paul's avatar
Paul committed
126

127
128
    shape normalize_standard() const;

129
130
131
    shape with_lens(type_t t, const std::vector<std::size_t>& l) const;
    shape with_lens(const std::vector<std::size_t>& l) const;

Paul's avatar
Paul committed
132
133
    friend bool operator==(const shape& x, const shape& y);
    friend bool operator!=(const shape& x, const shape& y);
Paul's avatar
Paul committed
134
    friend std::ostream& operator<<(std::ostream& os, const shape& x);
Paul's avatar
Paul committed
135

Paul's avatar
Paul committed
136
    template <class T>
Paul's avatar
Paul committed
137
138
    struct as
    {
139
        using type = std::conditional_t<std::is_same<T, bool>{}, int8_t, T>;
Paul's avatar
Paul committed
140

141
142
143
144
        type max() const { return std::numeric_limits<type>::max(); }

        type min() const { return std::numeric_limits<type>::lowest(); }

Paul's avatar
Paul committed
145
        template <class U>
146
        type operator()(U u) const
Paul's avatar
Paul committed
147
        {
148
            return type(u);
Paul's avatar
Paul committed
149
150
        }

Paul's avatar
Paul committed
151
        template <class U>
152
        type* operator()(U* u) const
Paul's avatar
Paul committed
153
        {
154
            return static_cast<type*>(u);
Paul's avatar
Paul committed
155
156
        }

Paul's avatar
Paul committed
157
        template <class U>
158
        const type* operator()(const U* u) const
Paul's avatar
Paul committed
159
        {
160
            return static_cast<type*>(u);
Paul's avatar
Paul committed
161
162
        }

163
        type operator()() const { return {}; }
Paul's avatar
Paul committed
164

165
        std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; }
Paul's avatar
Paul committed
166

Paul's avatar
Paul committed
167
        template <class U>
168
        type* from(U* buffer, std::size_t n = 0) const
Paul's avatar
Paul committed
169
        {
170
            return reinterpret_cast<type*>(buffer) + n;
Paul's avatar
Paul committed
171
        }
Paul's avatar
Paul committed
172

Paul's avatar
Paul committed
173
        template <class U>
174
        const type* from(const U* buffer, std::size_t n = 0) const
Paul's avatar
Paul committed
175
        {
176
            return reinterpret_cast<const type*>(buffer) + n;
Paul's avatar
Paul committed
177
        }
Paul's avatar
Paul committed
178

179
        type_t type_enum() const { return get_type<type>{}; }
Paul's avatar
Paul committed
180
181
    };

Paul's avatar
Paul committed
182
    template <class Visitor>
183
    static void visit(type_t t, Visitor v)
Paul's avatar
Paul committed
184
    {
185
        switch(t)
Paul's avatar
Paul committed
186
        {
Paul's avatar
Paul committed
187
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \
Paul's avatar
Paul committed
188
    case x: v(as<t>()); return;
Paul's avatar
Paul committed
189
190
            MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE
Paul's avatar
Paul committed
191
        }
Paul's avatar
Paul committed
192
        MIGRAPHX_THROW("Unknown type");
Paul's avatar
Paul committed
193
    }
Paul's avatar
Paul committed
194

195
196
197
198
199
200
    template <class Visitor>
    void visit_type(Visitor v) const
    {
        visit(this->type(), v);
    }

Paul's avatar
Paul committed
201
202
203
    template <class Visitor>
    static void visit_types(Visitor v)
    {
Paul's avatar
Paul committed
204
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) v(as<t>());
Paul's avatar
Paul committed
205
206
207
208
        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
    }

209
    std::string type_string() const;
210
    static type_t parse_type(const std::string& s);
211

Paul's avatar
Paul committed
212
    private:
Paul's avatar
Paul committed
213
    std::shared_ptr<const shape_impl> impl;
Paul's avatar
Paul committed
214

Paul's avatar
Paul committed
215
216
217
    std::size_t element_space() const;
};

218
219
220
// Conversion hooks between shape and migraphx::value; both are defined out of line.
// Judging by constness, migraphx_to_value writes s into v and migraphx_from_value
// reads v into s.
// NOTE(review): presumably found by the generic value-conversion code via these
// names / ADL — confirm.
void migraphx_to_value(value& v, const shape& s);
void migraphx_from_value(const value& v, shape& s);

Paul's avatar
Paul committed
221
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
222
} // namespace migraphx
Paul's avatar
Paul committed
223
224

#endif