check_shapes.hpp 6.96 KB
Newer Older
Paul's avatar
Paul committed
1
2
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
Paul's avatar
Paul committed
3

Paul's avatar
Paul committed
4
#include <migraphx/shape.hpp>
5
6
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
8
9
#include <algorithm>

Paul's avatar
Paul committed
10
namespace migraphx {
Paul's avatar
Paul committed
11
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
12
13
14

struct check_shapes
{
Paul's avatar
Paul committed
15
16
    const shape* begin;
    const shape* end;
Paul's avatar
Paul committed
17
18
    const std::string name;

Paul's avatar
Paul committed
19
20
21
    check_shapes(const shape* b, const shape* e, const std::string& n) : begin(b), end(e), name(n)
    {
    }
Paul's avatar
Paul committed
22

Paul's avatar
Paul committed
23
24
    template <class Op>
    check_shapes(const shape* b, const shape* e, const Op& op) : begin(b), end(e), name(op.name())
Paul's avatar
Paul committed
25
26
27
    {
    }

Paul's avatar
Paul committed
28
    template <class Op>
Paul's avatar
Paul committed
29
30
    check_shapes(const std::vector<shape>& s, const Op& op)
        : begin(s.data()), end(s.data() + s.size()), name(op.name())
Paul's avatar
Paul committed
31
32
33
34
35
36
37
38
39
40
41
    {
    }

    std::string prefix() const
    {
        if(name.empty())
            return "";
        else
            return name + ": ";
    }

Paul's avatar
Paul committed
42
    std::size_t size() const
Paul's avatar
Paul committed
43
    {
Paul's avatar
Paul committed
44
        if(begin == end)
Paul's avatar
Paul committed
45
            return 0;
Paul's avatar
Paul committed
46
47
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
48
49
50
        return end - begin;
    }

51
52
    template <class... Ts>
    const check_shapes& has(Ts... ns) const
Paul's avatar
Paul committed
53
    {
54
55
56
        if(migraphx::none_of({ns...}, [&](auto i) { return this->size() == i; }))
            MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected " +
                           to_string_range({ns...}) + " but given " + std::to_string(size()));
Paul's avatar
Paul committed
57
58
59
        return *this;
    }

60
61
62
63
64
65
66
    const check_shapes& nelements(std::size_t n) const
    {
        if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
            MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements");
        return *this;
    }

Paul's avatar
Paul committed
67
68
    const check_shapes& only_dims(std::size_t n) const
    {
Paul's avatar
Paul committed
69
70
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
71
        if(begin != end)
Paul's avatar
Paul committed
72
        {
Paul's avatar
Paul committed
73
            if(begin->lens().size() != n)
Paul's avatar
Paul committed
74
                MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
Paul's avatar
Paul committed
75
76
77
78
        }
        return *this;
    }

kahmed10's avatar
kahmed10 committed
79
80
81
82
83
84
85
86
87
88
89
90
91
    const check_shapes& max_ndims(std::size_t n) const
    {
        assert(begin != nullptr);
        assert(end != nullptr);
        if(begin != end)
        {
            if(begin->lens().size() > n)
                MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) +
                               " dimensions");
        }
        return *this;
    }

92
93
94
95
96
97
98
99
100
101
102
103
104
    const check_shapes& min_ndims(std::size_t n) const
    {
        assert(begin != nullptr);
        assert(end != nullptr);
        if(begin != end)
        {
            if(begin->lens().size() < n)
                MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
                               " dimensions");
        }
        return *this;
    }

Paul's avatar
Paul committed
105
106
107
    const check_shapes& same_shape() const
    {
        if(!this->same([](const shape& s) { return s; }))
Paul's avatar
Paul committed
108
            MIGRAPHX_THROW(prefix() + "Shapes do not match");
Paul's avatar
Paul committed
109
110
111
112
113
114
        return *this;
    }

    const check_shapes& same_type() const
    {
        if(!this->same([](const shape& s) { return s.type(); }))
Paul's avatar
Paul committed
115
            MIGRAPHX_THROW(prefix() + "Types do not match");
Paul's avatar
Paul committed
116
117
118
119
120
121
        return *this;
    }

    const check_shapes& same_dims() const
    {
        if(!this->same([](const shape& s) { return s.lens(); }))
Paul's avatar
Paul committed
122
            MIGRAPHX_THROW(prefix() + "Dimensions do not match");
Paul's avatar
Paul committed
123
124
125
126
127
128
        return *this;
    }

    const check_shapes& same_ndims() const
    {
        if(!this->same([](const shape& s) { return s.lens().size(); }))
Paul's avatar
Paul committed
129
            MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
Paul's avatar
Paul committed
130
131
132
133
134
135
        return *this;
    }

    const check_shapes& standard() const
    {
        if(!this->all_of([](const shape& s) { return s.standard(); }))
Paul's avatar
Paul committed
136
            MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout");
Paul's avatar
Paul committed
137
138
139
        return *this;
    }

Paul's avatar
Paul committed
140
141
142
143
144
145
146
    const check_shapes& standard_or_scalar() const
    {
        if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
            MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
        return *this;
    }

Paul's avatar
Paul committed
147
148
149
    const check_shapes& packed() const
    {
        if(!this->all_of([](const shape& s) { return s.packed(); }))
Paul's avatar
Paul committed
150
            MIGRAPHX_THROW(prefix() + "Shapes are not packed");
Paul's avatar
Paul committed
151
152
153
154
155
156
        return *this;
    }

    const check_shapes& not_transposed() const
    {
        if(!this->all_of([](const shape& s) { return not s.transposed(); }))
Paul's avatar
Paul committed
157
            MIGRAPHX_THROW(prefix() + "Shapes are transposed");
Paul's avatar
Paul committed
158
159
160
161
162
        return *this;
    }

    const check_shapes& not_broadcasted() const
    {
Paul's avatar
Paul committed
163
        if(!this->all_of([](const shape& s) { return not s.broadcasted(); }))
Paul's avatar
Paul committed
164
            MIGRAPHX_THROW(prefix() + "Shapes are broadcasted");
Paul's avatar
Paul committed
165
166
167
        return *this;
    }

Paul's avatar
Paul committed
168
169
170
171
172
173
174
    const check_shapes& elements(std::size_t n) const
    {
        if(!this->all_of([&](const shape& s) { return s.elements() == n; }))
            MIGRAPHX_THROW(prefix() + "Wrong number of elements");
        return *this;
    }

175
176
177
178
179
180
181
    const check_shapes& batch_not_transposed() const
    {
        if(!this->all_of([&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
            MIGRAPHX_THROW(prefix() + "Batch size is transposed");
        return *this;
    }

Paul's avatar
Paul committed
182
183
184
    template <class F>
    bool same(F f) const
    {
Paul's avatar
Paul committed
185
        if(begin == end)
Paul's avatar
Paul committed
186
            return true;
Paul's avatar
Paul committed
187
188
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
189
        auto&& key = f(*begin);
Paul's avatar
Paul committed
190
191
192
193
194
195
        return this->all_of([&](const shape& s) { return f(s) == key; });
    }

    template <class Predicate>
    bool all_of(Predicate p) const
    {
Paul's avatar
Paul committed
196
197
        if(begin == end)
            return true;
Paul's avatar
Paul committed
198
199
200
201
202
        assert(begin != nullptr);
        assert(end != nullptr);
        return std::all_of(begin, end, p);
    }

203
    const shape* get(long i) const
Paul's avatar
Paul committed
204
    {
Paul's avatar
Paul committed
205
        if(i >= size())
Paul's avatar
Paul committed
206
            MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds");
Paul's avatar
Paul committed
207
208
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
209
        if(i < 0)
Paul's avatar
Paul committed
210
211
            return end - i;
        return begin + i;
Paul's avatar
Paul committed
212
213
    }

214
    check_shapes slice(long start) const { return {get(start), end, name}; }
Paul's avatar
Paul committed
215

216
    check_shapes slice(long start, long last) const { return {get(start), get(last), name}; }
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238

    private:
    static bool batch_not_transposed_strides(const std::vector<std::size_t>& strides)
    {
        if(strides.size() <= 2)
            return true;
        auto dim_0       = strides.size() - 2;
        auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
        std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
        if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
        {
            return false;
        }

        if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
               return (i < j or i < matrix_size or j < matrix_size);
           }) != batch.end())
        {
            return false;
        }
        return true;
    }
Paul's avatar
Paul committed
239
240
};

Paul's avatar
Paul committed
241
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
242
} // namespace migraphx
Paul's avatar
Paul committed
243
244

#endif