"magic_pdf/vscode:/vscode.git/clone" did not exist on "fdf47155e302330798ffa10995e200edd2a694e9"
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"

#include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/is_detected.hpp"

#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"

namespace ck {
namespace wrapper {

// Disable from doxygen docs generation
/// @cond
// Forward declaration of the layout class template so that the helpers in
// this header can name Layout<Shape, UnrolledDescriptorType> before its
// definition is visible.
template <typename Shape, typename UnrolledDescriptorType>
struct Layout;

// Detection probe: this alias is well-formed only when `IsTuple()` can be
// called on an lvalue of type T. It is used together with is_detected to
// check whether a type is tuple-like.
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

namespace {
32
33
34
35
36
37
/**
 * \brief Generate packed (column-major) strides if not passed
 *
 * \param shape Tensor shape.
 * \return Generated column-major strides.
 */
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return generate_tuple(
        [&](auto i) {
            if constexpr(i.value == 0)
            {
                return Number<1>{};
            }
            else
            {
                return TupleReduce<Number<0>{}.value, i.value>([](auto x, auto y) { return x * y; },
                                                               unrolled_shape);
            }
        },
        Number<decltype(unrolled_shape)::Size()>{});
}

58
59
60
61
62
63
64
/**
 * \brief Create naive tensor descriptor from nested shape.
 *
 * \param shape Tensor shape.
 * \param strides Tensor strides.
 * \return Unrolled descriptor
 */
65
template <typename LayoutShape, typename LayoutStrides>
66
67
__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape,
                                                          const LayoutStrides& strides)
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    if constexpr(is_same_v<LayoutStrides, Tuple<>>)
    {
        // if not passed, then generate
        const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
                      "Size of strides and shape are not consistent.");
        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
    }
    else
    {
        const auto unrolled_strides = UnrollNestedTuple(strides);
        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
                      "Size of strides and shape are not consistent.");
        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
    }
}
} // namespace

/// @endcond

// make_*
/**
 * \brief Make layout function.
 *
 * \tparam Shape Shape for layout.
 * \tparam Strides Strides for layout.
 * \return Constructed layout.
 */
template <typename Shape, typename Strides>
99
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
100
{
101
102
    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
103
104
105
106
107
108
109
110
111
112
}

/**
 * \brief Make layout function with packed strides
 *        (column-major).
 *
 * \tparam Shape Shape for layout.
 * \return Constructed layout.
 */
template <typename Shape>
113
__host__ __device__ constexpr auto make_layout(const Shape& shape)
114
{
115
116
    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
117
118
119
120
}

// Layout helpers
// get
/**
 * \private
 * \brief Get dim.
 *
 * Identity overload: returns a copy of the argument unchanged.
 *
 * \param dim Dimension.
 * \return The same dimension.
 */
template <typename T>
__host__ __device__ T constexpr get(const T& dim)
{
    return dim;
}

/**
 * \brief Get element from tuple (Shape/Strides/Idxs).
 *
 * \tparam idx Index to lookup.
 * \param tuple Tuple to lookup.
 * \return Requested element.
 */
template <index_t idx, typename... Dims>
__host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
{
    // Compile-time position of the element to fetch.
    constexpr auto position = Number<idx>{};
    return tuple.At(position);
}

/**
 * \brief Get sub layout.
 *
 * \tparam idx Index to lookup.
 * \param layout Layout to create sub layout.
 * \return Requsted sub layout.
 */
155
156
template <index_t idx, typename Shape, typename FlattenDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
157
{
158
159
    const auto& shape    = layout.GetShape();
    const auto new_shape = get<idx>(shape);
160
161
    static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
                  "Shape of sub layout must be tuple");
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196

    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
    constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size();
    constexpr auto shape_offset   = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size();

    const auto unrolled_shape = UnrollNestedTuple(shape);
    const auto transforms     = generate_tuple(
        [&](auto i) {
            // Compare Idx with shape
            if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
            {
                // Remove dimension
                return make_freeze_transform(Number<0>{});
            }
            else
            {
                return make_pass_through_transform(unrolled_shape.At(i));
            }
        },
        Number<old_shape_dims>{});

    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
    const auto upper_dims = generate_tuple(
        [&](auto i) {
            if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
                return Sequence<>{};

            else
            {
                return Sequence<i.value - shape_offset>{};
            }
        },
        Number<old_shape_dims>{});

197
    const auto& flatten_desc = layout.GetUnrolledDescriptor();
198
199
    auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
    return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
}

/**
 * \brief Hierarchical get.
 *
 * Applies get<Idx> to the element first, then continues the lookup on the
 * result with the remaining indexes.
 *
 * \tparam Idx First index to lookup.
 * \tparam Idxs Next indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested element.
 */
template <index_t Idx, index_t... Idxs, typename T>
__host__ __device__ constexpr auto get(const T& elem)
{
    // Peel off the outermost level, then recurse into it.
    const auto outer = get<Idx>(elem);
    return get<Idxs...>(outer);
}

// size
/**
 * \private
 * \brief Get size.
 *
 * Identity overload: returns a copy of the argument unchanged.
 *
 * \param dim Size.
 * \return The same size.
 */
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
{
    return dim;
}

/**
 * \brief Length get (product if tuple).
 *
 * \tparam idx Index to lookup.
233
 * \param layout Layout to get Shape of.
234
235
 * \return Requsted length.
 */
236
237
template <index_t idx, typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
238
239
240
241
242
243
244
245
246
247
248
{
    return layout.template GetLength<idx>();
}

/**
 * \brief Shape size (product of dims).
 *
 * \param shape Shape to lookup.
 * \return Requsted size.
 */
template <typename... ShapeDims>
249
__host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
250
251
252
253
254
255
256
257
258
259
260
261
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
                                                 unrolled_shape);
}

/**
 * \brief Layout size (product of dims).
 *
 * \param layout Layout to calculate shape size.
 * \return Requsted size.
 */
262
263
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
264
265
266
267
268
269
270
271
272
273
274
275
{
    return layout.GetLengths();
}

/**
 * \brief Length get from tuple (product if tuple).
 *
 * \tparam idx Index to lookup.
 * \param tuple Tuple to lookup.
 * \return Requsted length.
 */
template <index_t idx, typename... Ts>
276
__host__ __device__ constexpr auto size(const Tuple<Ts...>& tuple)
277
278
279
280
281
282
283
{
    return size(tuple.At(Number<idx>{}));
}

/**
 * \brief Hierarchical size.
 *
284
285
 * \tparam Idx First index to lookup (to avoid empty Idxs).
 * \tparam Idxs Next indexes to lookup.
286
287
288
 * \param elem Element to lookup.
 * \return Requsted element.
 */
289
template <index_t Idx, index_t... Idxs, typename T>
290
291
__host__ __device__ constexpr auto size(const T& elem)
{
292
    return size(get<Idx, Idxs...>(elem));
293
294
295
296
297
298
299
300
301
}

// rank
/**
 * \brief Get layout rank (num elements in shape).
 *
 * \param layout Layout to calculate rank.
 * \return Requsted rank.
 */
302
template <typename Shape, typename UnrolledDescriptorType>
303
__host__ __device__ constexpr auto
304
rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
{
    return Shape::Size();
}

/**
 * \brief Get tuple rank (num elements in tuple).
 *
 * Only the tuple type is used; the argument itself is not read. Scalars are
 * handled by separate overloads that return 1.
 *
 * \param tuple Tuple to calculate rank.
 * \return Requested rank.
 */
template <typename... Dims>
__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& tuple)
{
    using TupleType = Tuple<Dims...>;
    return TupleType::Size();
}

/**
 * \private
324
325
326
327
 * \brief Rank for scalar
 *
 * \param dim Dimension scalar.
 * \return Returned 1.
328
329
 */
template <index_t IDim>
330
__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
331
332
333
334
335
336
{
    return 1;
}

/**
 * \private
337
338
339
340
 * \brief Rank for scalar
 *
 * \param dim Dimension scalar.
 * \return Returned 1.
341
 */
342
__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; }
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363

/**
 * \brief Hierarchical rank.
 *
 * Looks the element up with get<Idxs...> and returns the rank of the result.
 *
 * \tparam Idxs Indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested rank.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto rank(const T& elem)
{
    const auto selected = get<Idxs...>(elem);
    return rank(selected);
}

// depth
/**
 * \brief Get depth of the layout shape (return 0 if scalar).
 *
 * \param layout Layout to calculate depth.
 * \return Requsted depth.
 */
364
365
template <typename Shape, typename UnrolledDescriptorType>
__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
366
{
367
368
    const auto& shape = layout.GetShape();
    return TupleDepth(shape);
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
}

/**
 * \brief Get depth of the tuple. (return 0 if scalar)
 *
 * Thin wrapper over TupleDepth; scalars are handled by separate overloads
 * that return 0.
 *
 * \param tuple Tuple to calculate depth.
 * \return Requested depth.
 */
template <typename... Dims>
__host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
{
    const auto nesting_level = TupleDepth(tuple);
    return nesting_level;
}

/**
 * \private
385
386
387
388
 * \brief Depth for scalar
 *
 * \param dim Scalar.
 * \return Returned 0.
389
390
 */
template <index_t IDim>
391
__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
392
393
394
395
396
397
{
    return 0;
}

/**
 * \private
398
399
400
401
 * \brief Depth for scalar
 *
 * \param dim Scalar.
 * \return Returned 0.
402
 */
403
__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; }
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420

/**
 * \brief Hierarchical depth.
 *
 * Looks the element up with get<Idxs...> and returns the depth of the result.
 *
 * \tparam Idxs Indexes to lookup.
 * \param elem Element to lookup.
 * \return Requested depth.
 */
template <index_t... Idxs, typename T>
__host__ __device__ constexpr auto depth(const T& elem)
{
    const auto selected = get<Idxs...>(elem);
    return depth(selected);
}

/**
 * \brief Get Layout shape.
 *
421
 * \param layout Layout to get shape from.
422
423
 * \return Requsted shape.
 */
424
425
template <typename LayoutType>
__host__ __device__ constexpr const auto& shape(const LayoutType& layout)
426
427
428
429
430
431
{
    return layout.GetShape();
}

} // namespace wrapper
} // namespace ck