// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"

#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/is_detected.hpp"

#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

// Disable from doxygen docs generation
/// @cond INTERNAL
// forward declaration
// Layout is parameterized by its (possibly nested) shape type and by the type of
// the unrolled tensor descriptor that backs it.
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;

// Detection archetype: this alias is well-formed only when an lvalue of type T
// provides a callable IsTuple() member. Meant to be used through is_detected,
// e.g. is_detected<is_tuple, T>::value.
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

namespace {
/**
 * \brief Generate packed (column-major) strides if not passed
 *
 * \param shape Tensor shape.
 * \return Generated column-major strides.
 */
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return generate_tuple(
        [&](auto i) {
            if constexpr(i.value == 0)
            {
                return Number<1>{};
            }
            else
            {
                return TupleReduce<Number<0>{}.value, i.value>([](auto x, auto y) { return x * y; },
                                                               unrolled_shape);
            }
        },
        Number<decltype(unrolled_shape)::Size()>{});
}

/**
 * \brief Create naive tensor descriptor from nested shape.
 *
 * \param shape Tensor shape.
 * \param strides Tensor strides.
 * \return Unrolled descriptor
 */
68
template <typename LayoutShape, typename LayoutStrides>
69
70
__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape,
                                                          const LayoutStrides& strides)
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    if constexpr(is_same_v<LayoutStrides, Tuple<>>)
    {
        // if not passed, then generate
        const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
                      "Size of strides and shape are not consistent.");
        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
    }
    else
    {
        const auto unrolled_strides = UnrollNestedTuple(strides);
        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
                      "Size of strides and shape are not consistent.");
        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
    }
}
} // namespace

/// @endcond

// make_*
/**
 * \brief Make layout function.
 *
 * \tparam Shape Shape for layout.
 * \tparam Strides Strides for layout.
 * \return Constructed layout.
 */
template <typename Shape, typename Strides>
102
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
103
{
104
105
    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
106
107
108
109
110
111
112
113
114
115
}

/**
 * \brief Make layout function with packed strides
 *        (column-major).
 *
 * \tparam Shape Shape for layout.
 * \return Constructed layout.
 */
template <typename Shape>
116
__host__ __device__ constexpr auto make_layout(const Shape& shape)
117
{
118
119
    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
120
121
122
123
}

// Layout helpers
// get

/**
 * \private
 * \brief Identity get: hands back the dimension it was given.
 *
 * \param dim Dimension.
 * \return Copy of the passed dimension.
 */
template <typename T>
__host__ __device__ constexpr T get(const T& dim)
{
    return dim;
}

/**
 * \brief Fetch a single element of a tuple (Shape/Strides/Idxs).
 *
 * \tparam idx Position of the element to fetch.
 * \param tuple Tuple to read from.
 * \return Element stored at position idx.
 */
template <index_t idx, typename... Dims>
__host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
{
    constexpr auto position = Number<idx>{};
    return tuple.At(position);
}

/**
 * \brief Get sub layout.
 *
 * \tparam idx Index to lookup.
 * \param layout Layout to create sub layout.
 * \return Requsted sub layout.
 */
158
159
template <index_t idx, typename Shape, typename FlattenDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
160
{
161
162
    const auto& shape    = layout.GetShape();
    const auto new_shape = get<idx>(shape);
163
164
    static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
                  "Shape of sub layout must be tuple");
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199

    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
    constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size();
    constexpr auto shape_offset   = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size();

    const auto unrolled_shape = UnrollNestedTuple(shape);
    const auto transforms     = generate_tuple(
        [&](auto i) {
            // Compare Idx with shape
            if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
            {
                // Remove dimension
                return make_freeze_transform(Number<0>{});
            }
            else
            {
                return make_pass_through_transform(unrolled_shape.At(i));
            }
        },
        Number<old_shape_dims>{});

    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
    const auto upper_dims = generate_tuple(
        [&](auto i) {
            if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
                return Sequence<>{};

            else
            {
                return Sequence<i.value - shape_offset>{};
            }
        },
        Number<old_shape_dims>{});

200
    const auto& flatten_desc = layout.GetUnrolledDescriptor();
201
202
    auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
    return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
}

/**
 * \brief Hierarchical get.
 *
 * Applies get<Idx> to the element first, then keeps descending with the
 * remaining indexes.
 *
 * \tparam Idx First index to lookup.
 * \tparam Idxs Remaining indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested element.
 */
template <index_t Idx, index_t... Idxs, typename T>
__host__ __device__ constexpr auto get(const T& elem)
{
    const auto selected = get<Idx>(elem);
    return get<Idxs...>(selected);
}

// size
/**
 * \private
 * \brief Identity size: hands back the size it was given.
 *
 * \param dim Size.
 * \return Copy of the passed size.
 */
template <typename T>
__host__ __device__ constexpr T size(const T& dim)
{
    return dim;
}

/**
 * \brief Length get (product if tuple).
 *
 * \tparam idx Index to lookup.
236
 * \param layout Layout to get Shape of.
237
238
 * \return Requsted length.
 */
239
240
template <index_t idx, typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
241
242
243
244
245
246
247
248
249
250
251
{
    return layout.template GetLength<idx>();
}

/**
 * \brief Shape size (product of dims).
 *
 * \param shape Shape to lookup.
 * \return Requsted size.
 */
template <typename... ShapeDims>
252
__host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
253
254
255
256
257
258
259
260
261
262
263
264
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
                                                 unrolled_shape);
}

/**
 * \brief Layout size (product of dims).
 *
 * \param layout Layout to calculate shape size.
 * \return Requsted size.
 */
265
266
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
267
268
269
270
271
272
273
274
275
276
277
278
{
    return layout.GetLengths();
}

/**
 * \brief Length get from tuple (product if tuple).
 *
 * \tparam idx Index to lookup.
 * \param tuple Tuple to lookup.
 * \return Requsted length.
 */
template <index_t idx, typename... Ts>
279
__host__ __device__ constexpr auto size(const Tuple<Ts...>& tuple)
280
281
282
283
284
285
286
{
    return size(tuple.At(Number<idx>{}));
}

/**
 * \brief Hierarchical size.
 *
287
288
 * \tparam Idx First index to lookup (to avoid empty Idxs).
 * \tparam Idxs Next indexes to lookup.
289
290
291
 * \param elem Element to lookup.
 * \return Requsted element.
 */
292
template <index_t Idx, index_t... Idxs, typename T>
293
294
__host__ __device__ constexpr auto size(const T& elem)
{
295
    return size(get<Idx, Idxs...>(elem));
296
297
298
299
300
301
302
303
304
}

// rank
/**
 * \brief Get layout rank (num elements in shape).
 *
 * \param layout Layout to calculate rank.
 * \return Requsted rank.
 */
305
template <typename Shape, typename UnrolledDescriptorType>
306
__host__ __device__ constexpr auto
307
rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
{
    return Shape::Size();
}

/**
 * \brief Get tuple rank (num elements in tuple).
 *        Scalars are handled by the dedicated scalar overloads of rank.
 *
 * \param tuple Tuple to calculate rank (only its type is used).
 * \return Requested rank.
 */
template <typename... Dims>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& tuple)
{
    using TupleType = Tuple<Dims...>;
    return TupleType::Size();
}

/**
 * \private
327
328
329
330
 * \brief Rank for scalar
 *
 * \param dim Dimension scalar.
 * \return Returned 1.
331
332
 */
template <index_t IDim>
333
__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
334
335
336
337
338
339
{
    return 1;
}

/**
 * \private
340
341
342
343
 * \brief Rank for scalar
 *
 * \param dim Dimension scalar.
 * \return Returned 1.
344
 */
345
__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; }

/**
 * \brief Hierarchical rank.
 *
 * Selects the nested element addressed by Idxs and reports its rank.
 *
 * \tparam Idxs Indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested rank.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto rank(const T& elem)
{
    const auto selected = get<Idxs...>(elem);
    return rank(selected);
}

// depth
/**
 * \brief Get depth of the layout shape (return 0 if scalar).
 *
 * \param layout Layout to calculate depth.
 * \return Requsted depth.
 */
367
368
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
369
{
370
371
    const auto& shape = layout.GetShape();
    return TupleDepth(shape);
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
}

/**
 * \brief Get depth of the tuple.
 *        Scalars are handled by the dedicated scalar overloads of depth.
 *
 * \param tuple Tuple to calculate depth.
 * \return Requested depth, as computed by TupleDepth.
 */
template <typename... Dims>
__host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
{
    const auto nesting = TupleDepth(tuple);
    return nesting;
}

/**
 * \private
388
389
390
391
 * \brief Depth for scalar
 *
 * \param dim Scalar.
 * \return Returned 0.
392
393
 */
template <index_t IDim>
394
__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
395
396
397
398
399
400
{
    return 0;
}

/**
 * \private
401
402
403
404
 * \brief Depth for scalar
 *
 * \param dim Scalar.
 * \return Returned 0.
405
 */
406
__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; }

/**
 * \brief Hierarchical depth.
 *
 * Selects the nested element addressed by Idxs and reports its depth.
 *
 * \tparam Idxs Indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested depth.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto depth(const T& elem)
{
    const auto selected = get<Idxs...>(elem);
    return depth(selected);
}

/**
 * \brief Get Layout shape.
 *
424
 * \param layout Layout to get shape from.
425
426
 * \return Requsted shape.
 */
427
428
template <typename LayoutType>
__host__ __device__ constexpr const auto& shape(const LayoutType& layout)
429
430
431
432
433
434
{
    return layout.GetShape();
}

} // namespace wrapper
} // namespace ck