/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraphx/shape.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/config.hpp>
#include <algorithm>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct check_shapes
{
Paul's avatar
Paul committed
38
39
    const shape* begin;
    const shape* end;
Paul's avatar
Paul committed
40
    const std::string name;
Charlie Lin's avatar
Charlie Lin committed
41
    const bool dynamic_allowed;
Paul's avatar
Paul committed
42

Charlie Lin's avatar
Charlie Lin committed
43
44
    check_shapes(const shape* b, const shape* e, const std::string& n, const bool d = false)
        : begin(b), end(e), name(n), dynamic_allowed(d)
Paul's avatar
Paul committed
45
    {
Charlie Lin's avatar
Charlie Lin committed
46
        check_dynamic();
Paul's avatar
Paul committed
47
    }
Paul's avatar
Paul committed
48

Paul's avatar
Paul committed
49
    template <class Op>
Charlie Lin's avatar
Charlie Lin committed
50
51
    check_shapes(const shape* b, const shape* e, const Op& op, const bool d = false)
        : begin(b), end(e), name(op.name()), dynamic_allowed(d)
Paul's avatar
Paul committed
52
    {
Charlie Lin's avatar
Charlie Lin committed
53
        check_dynamic();
Paul's avatar
Paul committed
54
55
    }

Paul's avatar
Paul committed
56
    template <class Op>
Charlie Lin's avatar
Charlie Lin committed
57
58
    check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
        : begin(s.data()), end(s.data() + s.size()), name(op.name()), dynamic_allowed(d)
Paul's avatar
Paul committed
59
    {
Charlie Lin's avatar
Charlie Lin committed
60
61
62
63
64
65
66
67
68
        check_dynamic();
    }

    void check_dynamic() const
    {
        if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); }))
        {
            MIGRAPHX_THROW(prefix() + "Dynamic shapes not supported");
        }
Paul's avatar
Paul committed
69
70
71
72
73
74
75
76
77
78
    }

    std::string prefix() const
    {
        if(name.empty())
            return "";
        else
            return name + ": ";
    }

Paul's avatar
Paul committed
79
    std::size_t size() const
Paul's avatar
Paul committed
80
    {
Paul's avatar
Paul committed
81
        if(begin == end)
Paul's avatar
Paul committed
82
            return 0;
Paul's avatar
Paul committed
83
84
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
85
86
87
        return end - begin;
    }

Charlie Lin's avatar
Charlie Lin committed
88
89
90
91
92
    /*!
     * Check if the number of shape objects is equal to atleast one of the
     * given sizes.
     * \param ns template parameter pack of sizes to check against
     */
93
94
    template <class... Ts>
    const check_shapes& has(Ts... ns) const
Paul's avatar
Paul committed
95
    {
96
97
98
        if(migraphx::none_of({ns...}, [&](auto i) { return this->size() == i; }))
            MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected " +
                           to_string_range({ns...}) + " but given " + std::to_string(size()));
Paul's avatar
Paul committed
99
100
101
        return *this;
    }

102
103
    const check_shapes& nelements(std::size_t n) const
    {
104
        if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
105
106
107
108
            MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
109
110
111
112
113
    /*!
     * Check that the first shape has exactly n dimensions.
     * Do nothing if the container is empty.
     * \param n number of dimensions
     */
Paul's avatar
Paul committed
114
115
    const check_shapes& only_dims(std::size_t n) const
    {
Paul's avatar
Paul committed
116
117
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
118
        if(begin != end)
Paul's avatar
Paul committed
119
        {
Charlie Lin's avatar
Charlie Lin committed
120
            if(begin->max_lens().size() != n)
Paul's avatar
Paul committed
121
                MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
Paul's avatar
Paul committed
122
123
124
125
        }
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
126
127
128
129
130
    /*!
     * Check that the first shape has a maximum of n dimensions.
     * Do nothing if the container is empty.
     * \param n number of dimensions
     */
kahmed10's avatar
kahmed10 committed
131
132
133
134
135
136
    const check_shapes& max_ndims(std::size_t n) const
    {
        assert(begin != nullptr);
        assert(end != nullptr);
        if(begin != end)
        {
Charlie Lin's avatar
Charlie Lin committed
137
            if(begin->max_lens().size() > n)
kahmed10's avatar
kahmed10 committed
138
139
140
141
142
143
                MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) +
                               " dimensions");
        }
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
144
145
146
147
148
    /*!
     * Check that the first shape has a minimum of n dimensions.
     * Do nothing if the container is empty.
     * \param n number of dimensions
     */
149
150
151
152
153
154
    const check_shapes& min_ndims(std::size_t n) const
    {
        assert(begin != nullptr);
        assert(end != nullptr);
        if(begin != end)
        {
Charlie Lin's avatar
Charlie Lin committed
155
            if(begin->max_lens().size() < n)
156
157
158
159
160
161
                MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
                               " dimensions");
        }
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
162
163
164
    /*!
     * Check all shapes have the same shape.
     */
Paul's avatar
Paul committed
165
166
    const check_shapes& same_shape() const
    {
167
        if(not this->same([](const shape& s) { return s; }))
Paul's avatar
Paul committed
168
            MIGRAPHX_THROW(prefix() + "Shapes do not match");
Paul's avatar
Paul committed
169
170
171
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
172
173
174
    /*!
     * Check all shapes have the same type.
     */
Paul's avatar
Paul committed
175
176
    const check_shapes& same_type() const
    {
177
        if(not this->same([](const shape& s) { return s.type(); }))
Paul's avatar
Paul committed
178
            MIGRAPHX_THROW(prefix() + "Types do not match");
Paul's avatar
Paul committed
179
180
181
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
182
183
184
    /*!
     * Check all shapes have the same lens.
     */
Paul's avatar
Paul committed
185
186
    const check_shapes& same_dims() const
    {
187
        if(not this->same([](const shape& s) { return s.max_lens(); }))
Paul's avatar
Paul committed
188
            MIGRAPHX_THROW(prefix() + "Dimensions do not match");
Charlie Lin's avatar
Charlie Lin committed
189
        if(this->any_of([&](const shape& s) { return s.dynamic(); }))
190
            if(not this->same([](const shape& s) { return s.min_lens(); }))
Charlie Lin's avatar
Charlie Lin committed
191
                MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match");
Paul's avatar
Paul committed
192
193
194
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
195
196
197
    /*!
     * Check all shapes have the same number of dimensions.
     */
Paul's avatar
Paul committed
198
199
    const check_shapes& same_ndims() const
    {
200
        if(not this->same([](const shape& s) { return s.max_lens().size(); }))
Paul's avatar
Paul committed
201
            MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
Paul's avatar
Paul committed
202
203
204
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
205
206
207
    /*!
     * Check all shapes are standard.
     */
Paul's avatar
Paul committed
208
209
    const check_shapes& standard() const
    {
210
        if(not this->all_of([](const shape& s) { return s.standard(); }))
Paul's avatar
Paul committed
211
            MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout");
Paul's avatar
Paul committed
212
213
214
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
215
216
217
    /*!
     * Check all shapes are standard or scalar.
     */
Paul's avatar
Paul committed
218
219
    const check_shapes& standard_or_scalar() const
    {
220
        if(not this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
Paul's avatar
Paul committed
221
222
223
224
            MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
225
226
227
    /*!
     * Check all shapes are packed.
     */
Paul's avatar
Paul committed
228
229
    const check_shapes& packed() const
    {
230
        if(not this->all_of([](const shape& s) { return s.packed(); }))
Paul's avatar
Paul committed
231
            MIGRAPHX_THROW(prefix() + "Shapes are not packed");
Paul's avatar
Paul committed
232
233
234
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
235
236
237
    /*!
     * Check all shapes are packed or broadcasted.
     */
238
239
    const check_shapes& packed_or_broadcasted() const
    {
240
        if(not this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
241
242
243
244
            MIGRAPHX_THROW(prefix() + "Shapes are not packed nor broadcasted");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
245
246
247
    /*!
     * Check all shapes are tuples.
     */
Shucai Xiao's avatar
Shucai Xiao committed
248
249
    const check_shapes& tuple_type() const
    {
250
        if(not this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
Shucai Xiao's avatar
Shucai Xiao committed
251
252
253
254
            MIGRAPHX_THROW(prefix() + "Shapes are not tuple!");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
255
256
257
    /*!
     * Check all shapes are not transposed.
     */
Paul's avatar
Paul committed
258
259
    const check_shapes& not_transposed() const
    {
260
        if(not this->all_of([](const shape& s) { return not s.transposed(); }))
Paul's avatar
Paul committed
261
            MIGRAPHX_THROW(prefix() + "Shapes are transposed");
Paul's avatar
Paul committed
262
263
264
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
265
266
267
    /*!
     * Check all shapes are not broadcasted.
     */
Paul's avatar
Paul committed
268
269
    const check_shapes& not_broadcasted() const
    {
270
        if(not this->all_of([](const shape& s) { return not s.broadcasted(); }))
Paul's avatar
Paul committed
271
            MIGRAPHX_THROW(prefix() + "Shapes are broadcasted");
Paul's avatar
Paul committed
272
273
274
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
275
276
277
278
    /*!
     * Check all shapes have the same n elements.
     * \param n number of elements
     */
Paul's avatar
Paul committed
279
280
    const check_shapes& elements(std::size_t n) const
    {
281
        if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
Paul's avatar
Paul committed
282
283
284
285
            MIGRAPHX_THROW(prefix() + "Wrong number of elements");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
286
287
288
    /*!
     * Check the batches of all the shapes do not have transposed strides.
     */
289
290
    const check_shapes& batch_not_transposed() const
    {
291
292
        if(not this->all_of(
               [&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
293
294
295
296
            MIGRAPHX_THROW(prefix() + "Batch size is transposed");
        return *this;
    }

Paul's avatar
Paul committed
297
298
299
    template <class F>
    bool same(F f) const
    {
Paul's avatar
Paul committed
300
        if(begin == end)
Paul's avatar
Paul committed
301
            return true;
Paul's avatar
Paul committed
302
303
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
304
        auto&& key = f(*begin);
Paul's avatar
Paul committed
305
306
307
308
309
310
        return this->all_of([&](const shape& s) { return f(s) == key; });
    }

    template <class Predicate>
    bool all_of(Predicate p) const
    {
Paul's avatar
Paul committed
311
312
        if(begin == end)
            return true;
Paul's avatar
Paul committed
313
314
315
316
317
        assert(begin != nullptr);
        assert(end != nullptr);
        return std::all_of(begin, end, p);
    }

Charlie Lin's avatar
Charlie Lin committed
318
319
320
321
322
323
324
325
326
327
    template <class Predicate>
    bool any_of(Predicate p) const
    {
        if(begin == end)
            return false;
        assert(begin != nullptr);
        assert(end != nullptr);
        return std::any_of(begin, end, p);
    }

328
    const shape* get(long i) const
Paul's avatar
Paul committed
329
    {
Paul's avatar
Paul committed
330
        if(i >= size())
Paul's avatar
Paul committed
331
            MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds");
Paul's avatar
Paul committed
332
333
        assert(begin != nullptr);
        assert(end != nullptr);
Paul's avatar
Paul committed
334
        if(i < 0)
Paul's avatar
Paul committed
335
336
            return end - i;
        return begin + i;
Paul's avatar
Paul committed
337
338
    }

339
    check_shapes slice(long start) const { return {get(start), end, name}; }
Paul's avatar
Paul committed
340

341
    check_shapes slice(long start, long last) const { return {get(start), get(last), name}; }
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363

    private:
    static bool batch_not_transposed_strides(const std::vector<std::size_t>& strides)
    {
        if(strides.size() <= 2)
            return true;
        auto dim_0       = strides.size() - 2;
        auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
        std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
        if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
        {
            return false;
        }

        if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
               return (i < j or i < matrix_size or j < matrix_size);
           }) != batch.end())
        {
            return false;
        }
        return true;
    }
Paul's avatar
Paul committed
364
365
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif