/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
Paul's avatar
Paul committed
24
25
#ifndef MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_CHECK_SHAPES_HPP
Paul's avatar
Paul committed
26

27
#include <migraphx/permutation.hpp>
Paul's avatar
Paul committed
28
#include <migraphx/shape.hpp>
29
30
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
Paul's avatar
Paul committed
31
#include <migraphx/config.hpp>
Paul's avatar
Paul committed
32
33
#include <algorithm>

Paul's avatar
Paul committed
34
namespace migraphx {
Paul's avatar
Paul committed
35
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
36

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
// Trait that is true when T is incrementable, dereferenceable, and equality-comparable.
// The primary template derives from std::false_type so the trait always yields a bool:
// this lets check_shapes' static_assert print its own diagnostic for a non-iterator
// type (an empty primary template would instead fail as a non-bool expression), and
// makes is_iterator<T>::value usable in other contexts.
template <class, class = void>
struct is_iterator : std::false_type
{
};

// Selected only when all three expressions below are well-formed for T.
template <class T>
struct is_iterator<T,
                   std::void_t<decltype(++std::declval<T&>()),
                               decltype(*std::declval<T&>()),
                               decltype(std::declval<T&>() == std::declval<T&>())>> : std::true_type
{
};

template <class Iterator>
Paul's avatar
Paul committed
52
53
struct check_shapes
{
54
55
56
    static_assert(is_iterator<Iterator>{}, "CHECK_SHAPES: Deduced type must be an iterator");
    Iterator begin;
    Iterator end;
57
58
    std::string name;
    bool dynamic_allowed;
Paul's avatar
Paul committed
59

60
    check_shapes(Iterator b, Iterator e, const std::string& n, const bool d = false)
Charlie Lin's avatar
Charlie Lin committed
61
        : begin(b), end(e), name(n), dynamic_allowed(d)
Paul's avatar
Paul committed
62
    {
Charlie Lin's avatar
Charlie Lin committed
63
        check_dynamic();
Paul's avatar
Paul committed
64
    }
Paul's avatar
Paul committed
65

Paul's avatar
Paul committed
66
    template <class Op>
67
    check_shapes(Iterator b, Iterator e, const Op& op, const bool d = false)
Charlie Lin's avatar
Charlie Lin committed
68
        : begin(b), end(e), name(op.name()), dynamic_allowed(d)
Paul's avatar
Paul committed
69
    {
Charlie Lin's avatar
Charlie Lin committed
70
        check_dynamic();
Paul's avatar
Paul committed
71
72
    }

73
    template <class Op, MIGRAPHX_REQUIRES(not std::is_convertible<Op, std::string>{})>
Charlie Lin's avatar
Charlie Lin committed
74
    check_shapes(const std::vector<shape>& s, const Op& op, const bool d = false)
75
        : begin(s.begin()), end(s.end()), name(op.name()), dynamic_allowed(d)
Paul's avatar
Paul committed
76
    {
Charlie Lin's avatar
Charlie Lin committed
77
78
79
        check_dynamic();
    }

80
81
82
83
84
85
    check_shapes(const std::vector<shape>& s, const std::string& n, const bool d = false)
        : begin(s.begin()), end(s.end()), name(n), dynamic_allowed(d)
    {
        check_dynamic();
    }

Charlie Lin's avatar
Charlie Lin committed
86
87
88
89
90
91
    void check_dynamic() const
    {
        if(not dynamic_allowed and this->any_of([&](const shape& s) { return s.dynamic(); }))
        {
            MIGRAPHX_THROW(prefix() + "Dynamic shapes not supported");
        }
Paul's avatar
Paul committed
92
93
94
95
96
97
98
99
100
101
    }

    std::string prefix() const
    {
        if(name.empty())
            return "";
        else
            return name + ": ";
    }

Paul's avatar
Paul committed
102
    std::size_t size() const
Paul's avatar
Paul committed
103
    {
Paul's avatar
Paul committed
104
        if(begin == end)
Paul's avatar
Paul committed
105
106
107
108
            return 0;
        return end - begin;
    }

Charlie Lin's avatar
Charlie Lin committed
109
    /*!
110
     * Require the number of shape objects to equal to one of the
Charlie Lin's avatar
Charlie Lin committed
111
112
113
     * given sizes.
     * \param ns template parameter pack of sizes to check against
     */
114
115
    template <class... Ts>
    const check_shapes& has(Ts... ns) const
Paul's avatar
Paul committed
116
    {
117
118
119
        if(migraphx::none_of({ns...}, [&](auto i) { return this->size() == i; }))
            MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected " +
                           to_string_range({ns...}) + " but given " + std::to_string(size()));
Paul's avatar
Paul committed
120
121
122
        return *this;
    }

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
    /*!
     * Require the number of shape objects to equal at least a given amount.  Use this
     * method for ops that can take any number (variadic) of inputs.
     * \param n min. number of shapes
     */
    const check_shapes& has_at_least(std::size_t n) const
    {
        if(this->size() < n)
            MIGRAPHX_THROW(prefix() + "Wrong number of arguments: expected at least " +
                           to_string(n) + " but given " + std::to_string(size()));
        return *this;
    }

    /*!
     * Require all shapes to have the same number of elements.
     * \param n  number of
     */
140
141
    const check_shapes& nelements(std::size_t n) const
    {
142
        if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
143
144
145
146
            MIGRAPHX_THROW(prefix() + "Shapes must have only " + std::to_string(n) + " elements");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
147
148
149
150
151
    /*!
     * Check that the first shape has exactly n dimensions.
     * Do nothing if the container is empty.
     * \param n number of dimensions
     */
Paul's avatar
Paul committed
152
153
    const check_shapes& only_dims(std::size_t n) const
    {
Paul's avatar
Paul committed
154
        if(begin != end)
Paul's avatar
Paul committed
155
        {
Charlie Lin's avatar
Charlie Lin committed
156
            if(begin->ndim() != n)
Paul's avatar
Paul committed
157
                MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
Paul's avatar
Paul committed
158
159
160
161
        }
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
162
163
164
165
166
    /*!
     * Check that the first shape has a maximum of n dimensions.
     * Do nothing if the container is empty.
     * \param n number of dimensions
     */
kahmed10's avatar
kahmed10 committed
167
168
169
170
    const check_shapes& max_ndims(std::size_t n) const
    {
        if(begin != end)
        {
Charlie Lin's avatar
Charlie Lin committed
171
            if(begin->ndim() > n)
kahmed10's avatar
kahmed10 committed
172
173
174
175
176
177
                MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) +
                               " dimensions");
        }
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
178
179
180
181
182
    /*!
     * Check that the first shape has a minimum of n dimensions.
     * Do nothing if the container is empty.
     * \param n number of dimensions
     */
183
184
185
186
    const check_shapes& min_ndims(std::size_t n) const
    {
        if(begin != end)
        {
Charlie Lin's avatar
Charlie Lin committed
187
            if(begin->ndim() < n)
188
189
190
191
192
193
                MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
                               " dimensions");
        }
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
194
195
196
    /*!
     * Check all shapes have the same shape.
     */
Paul's avatar
Paul committed
197
198
    const check_shapes& same_shape() const
    {
199
        if(not this->same([](const shape& s) { return s; }))
Paul's avatar
Paul committed
200
            MIGRAPHX_THROW(prefix() + "Shapes do not match");
Paul's avatar
Paul committed
201
202
203
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
204
205
206
    /*!
     * Check all shapes have the same type.
     */
Paul's avatar
Paul committed
207
208
    const check_shapes& same_type() const
    {
209
        if(not this->same([](const shape& s) { return s.type(); }))
Paul's avatar
Paul committed
210
            MIGRAPHX_THROW(prefix() + "Types do not match");
Paul's avatar
Paul committed
211
212
213
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
214
215
216
    /*!
     * Check all shapes have the same lens.
     */
Paul's avatar
Paul committed
217
218
    const check_shapes& same_dims() const
    {
219
        if(not this->same([](const shape& s) { return s.max_lens(); }))
Paul's avatar
Paul committed
220
            MIGRAPHX_THROW(prefix() + "Dimensions do not match");
Charlie Lin's avatar
Charlie Lin committed
221
        if(this->any_of([&](const shape& s) { return s.dynamic(); }))
222
            if(not this->same([](const shape& s) { return s.min_lens(); }))
Charlie Lin's avatar
Charlie Lin committed
223
                MIGRAPHX_THROW(prefix() + "Min dynamic dimensions do not match");
Paul's avatar
Paul committed
224
225
226
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
227
228
229
    /*!
     * Check all shapes have the same number of dimensions.
     */
Paul's avatar
Paul committed
230
231
    const check_shapes& same_ndims() const
    {
Charlie Lin's avatar
Charlie Lin committed
232
        if(not this->same([](const shape& s) { return s.ndim(); }))
Paul's avatar
Paul committed
233
            MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
Paul's avatar
Paul committed
234
235
236
        return *this;
    }

237
238
239
240
241
242
243
244
245
246
    /*!
     * Check all shapes have the same layout.
     */
    const check_shapes& same_layout() const
    {
        if(not this->same([](const shape& s) { return find_permutation(s); }))
            MIGRAPHX_THROW(prefix() + "Layouts do not match");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
247
248
249
    /*!
     * Check all shapes are standard.
     */
Paul's avatar
Paul committed
250
251
    const check_shapes& standard() const
    {
252
        if(not this->all_of([](const shape& s) { return s.standard(); }))
Paul's avatar
Paul committed
253
            MIGRAPHX_THROW(prefix() + "Shapes are not in standard layout");
Paul's avatar
Paul committed
254
255
256
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
257
258
259
260
261
262
263
264
265
266
    /*!
     * Check all shapes are scalar.
     */
    const check_shapes& scalar() const
    {
        if(not this->all_of([](const shape& s) { return s.scalar(); }))
            MIGRAPHX_THROW(prefix() + "Shapes are not a scalar");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
267
268
269
    /*!
     * Check all shapes are standard or scalar.
     */
Paul's avatar
Paul committed
270
271
    const check_shapes& standard_or_scalar() const
    {
272
        if(not this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
Paul's avatar
Paul committed
273
274
275
276
            MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
277
278
279
    /*!
     * Check all shapes are packed.
     */
Paul's avatar
Paul committed
280
281
    const check_shapes& packed() const
    {
282
        if(not this->all_of([](const shape& s) { return s.packed(); }))
Paul's avatar
Paul committed
283
            MIGRAPHX_THROW(prefix() + "Shapes are not packed");
Paul's avatar
Paul committed
284
285
286
        return *this;
    }

287
288
289
290
291
292
293
294
295
296
297
298
299
    /*!
     * Check all shapes are packed with certain layouts
     */
    const check_shapes&
    packed_layouts(const std::initializer_list<std::vector<int64_t>>& layouts) const
    {
        if(not this->all_of([&](const shape& s) {
               return s.packed() and contains(layouts, find_permutation(s));
           }))
            MIGRAPHX_THROW(prefix() + "Shapes are not packed with correct layout");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
300
301
302
    /*!
     * Check all shapes are packed or broadcasted.
     */
303
304
    const check_shapes& packed_or_broadcasted() const
    {
305
        if(not this->all_of([](const shape& s) { return s.packed() or s.broadcasted(); }))
306
307
308
309
            MIGRAPHX_THROW(prefix() + "Shapes are not packed nor broadcasted");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
310
311
312
    /*!
     * Check all shapes are tuples.
     */
Shucai Xiao's avatar
Shucai Xiao committed
313
314
    const check_shapes& tuple_type() const
    {
315
        if(not this->all_of([](const shape& s) { return s.type() == shape::tuple_type; }))
Shucai Xiao's avatar
Shucai Xiao committed
316
317
318
319
            MIGRAPHX_THROW(prefix() + "Shapes are not tuple!");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
320
321
322
    /*!
     * Check all shapes are not transposed.
     */
Paul's avatar
Paul committed
323
324
    const check_shapes& not_transposed() const
    {
325
        if(not this->all_of([](const shape& s) { return not s.transposed(); }))
Paul's avatar
Paul committed
326
            MIGRAPHX_THROW(prefix() + "Shapes are transposed");
Paul's avatar
Paul committed
327
328
329
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
330
331
332
    /*!
     * Check all shapes are not broadcasted.
     */
Paul's avatar
Paul committed
333
334
    const check_shapes& not_broadcasted() const
    {
335
        if(not this->all_of([](const shape& s) { return not s.broadcasted(); }))
Paul's avatar
Paul committed
336
            MIGRAPHX_THROW(prefix() + "Shapes are broadcasted");
Paul's avatar
Paul committed
337
338
339
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
340
341
342
343
    /*!
     * Check all shapes have the same n elements.
     * \param n number of elements
     */
Paul's avatar
Paul committed
344
345
    const check_shapes& elements(std::size_t n) const
    {
346
        if(not this->all_of([&](const shape& s) { return s.elements() == n; }))
Paul's avatar
Paul committed
347
348
349
350
            MIGRAPHX_THROW(prefix() + "Wrong number of elements");
        return *this;
    }

Charlie Lin's avatar
Charlie Lin committed
351
352
353
    /*!
     * Check the batches of all the shapes do not have transposed strides.
     */
354
355
    const check_shapes& batch_not_transposed() const
    {
356
357
        if(not this->all_of(
               [&](const shape& s) { return batch_not_transposed_strides(s.strides()); }))
358
359
360
361
            MIGRAPHX_THROW(prefix() + "Batch size is transposed");
        return *this;
    }

Paul's avatar
Paul committed
362
363
364
    template <class F>
    bool same(F f) const
    {
Paul's avatar
Paul committed
365
        if(begin == end)
Paul's avatar
Paul committed
366
            return true;
Paul's avatar
Paul committed
367
        auto&& key = f(*begin);
Paul's avatar
Paul committed
368
369
370
371
372
373
        return this->all_of([&](const shape& s) { return f(s) == key; });
    }

    template <class Predicate>
    bool all_of(Predicate p) const
    {
Paul's avatar
Paul committed
374
375
        if(begin == end)
            return true;
Paul's avatar
Paul committed
376
377
378
        return std::all_of(begin, end, p);
    }

Charlie Lin's avatar
Charlie Lin committed
379
380
381
382
383
384
385
386
    template <class Predicate>
    bool any_of(Predicate p) const
    {
        if(begin == end)
            return false;
        return std::any_of(begin, end, p);
    }

387
    Iterator get(long i) const
Paul's avatar
Paul committed
388
    {
Paul's avatar
Paul committed
389
        if(i >= size())
Paul's avatar
Paul committed
390
            MIGRAPHX_THROW(prefix() + "Accessing shape out of bounds");
Paul's avatar
Paul committed
391
        if(i < 0)
Paul's avatar
Paul committed
392
393
            return end - i;
        return begin + i;
Paul's avatar
Paul committed
394
395
    }

396
    check_shapes slice(long start) const { return {get(start), end, name}; }
Paul's avatar
Paul committed
397

398
    check_shapes slice(long start, long last) const { return {get(start), get(last), name}; }
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420

    private:
    static bool batch_not_transposed_strides(const std::vector<std::size_t>& strides)
    {
        if(strides.size() <= 2)
            return true;
        auto dim_0       = strides.size() - 2;
        auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
        std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
        if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
        {
            return false;
        }

        if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
               return (i < j or i < matrix_size or j < matrix_size);
           }) != batch.end())
        {
            return false;
        }
        return true;
    }
Paul's avatar
Paul committed
421
422
};

423
424
425
426
427
// Class template argument deduction for the std::vector<shape> constructors.
// The iterator type cannot be deduced from those constructors' parameters, so this
// guide maps a vector plus any name-or-operation argument to the vector's
// const_iterator.
template <class NameOrOp>
check_shapes(const std::vector<shape>&, const NameOrOp&, bool = false)
    -> check_shapes<std::vector<shape>::const_iterator>;

Paul's avatar
Paul committed
428
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
429
} // namespace migraphx
Paul's avatar
Paul committed
430
431

#endif