// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/wrapper/utils/layout_utils.hpp"

namespace ck {
namespace wrapper {

/**
12
 * \brief Layout wrapper that performs the tensor descriptor logic.
13
14
15
16
 *
 * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
 *         (dynamic layout). It is possible to pass nested shapes
 *         (e.g. ((4, 2), 2)), nested dimensions are merged.
17
 * \tparam UnrolledDescriptorType Tensor descriptor for unnested shape dims.
18
 */
19
template <typename Shape, typename UnrolledDescriptorType>
20
21
22
23
24
25
struct Layout
{
    private:
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

26
27
28
29
30
31
    /**
     * \brief Generate default indices tuple (idx with all merged nested shapes)
     *
     * \param shape Shape to align.
     * \return Multi idx tuple with zeros.
     */
32
    template <typename... Ts>
33
34
    __host__ __device__ constexpr static auto
    GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape)
35
36
37
    {
        return generate_tuple(
            [&](auto) {
38
                if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
39
40
41
42
43
44
45
46
47
48
49
50
51
                {
                    // runtime layout
                    return index_t(0);
                }
                else
                {
                    // compiletime layout
                    return I0;
                }
            },
            Number<Tuple<Ts...>::Size()>{});
    }

52
53
54
55
56
57
58
59
60
    /**
     * \brief Generate lower dims in compile-time for the Merge transform using
     * provided type. If element of nested Tuple<Ts...> is also a tuple, then
     * merge (generate sequence for merge). If tuple is element, then pass
     * through (sequence with one element).
     *
     * \param shape Shape to align.
     * \return LowerDims for MergeTrasform.
     */
61
    template <typename Idx, typename... Ts>
62
63
    __host__ __device__ constexpr static auto
    GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape)
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
    {
        if constexpr(Idx::value == 0)
        {
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                // Return Sequence for the first tuple
                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence =
                    typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                // Return first element
                return Sequence<0>{};
            }
        }
        else
        {
            // Get previous element using recurence (in compile-time)
            using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
            const auto next_seq_val = PreviousSeqT::At(I0) + 1;
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence =
                    typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
                        type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                return Sequence<next_seq_val>{};
            }
        }
    }

103
104
105
106
107
108
109
110
111
112
113
    /**
     * \brief Iterate over the nested tuples in the shape.
     * Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
     * Example idx:     (1,      1), 1,      1
     * Example shape:   (2, (2, 2)), 2, (2, 2)
     * Unrolled shape:  2,  (2, 2),  2, (2, 2)
     *
     * \param shape Layout shape.
     * \param idx Idx to align.
     * \return Algined shape.
     */
114
    template <typename... ShapeDims, typename... IdxDims>
115
116
    __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
                                                              const Tuple<IdxDims...>& idx)
117
118
119
120
121
122
123
124
125
126
127
    {
        if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
        {
            // Index unrolled to flatten, return shape
            return shape;
        }
        else
        {
            // Iterate over shape tuple elements:
            // 1. If corresponding idx element is tuple then return (will be unrolled)
            // 2. If no, pack in tuple. It will be restored during unroll.
128
            auto aligned_shape = generate_tuple(
129
130
131
132
133
134
135
136
137
138
139
140
141
142
                [&](auto i) {
                    if constexpr(is_detected<is_tuple,
                                             tuple_element_t<i, Tuple<IdxDims...>>>::value)
                    {
                        return shape.At(i);
                    }
                    else
                    {
                        return make_tuple(shape.At(i));
                    }
                },
                Number<Tuple<IdxDims...>::Size()>{});

            // Unroll and process next step
143
144
            return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
                                   UnrollNestedTuple<0, 1>(idx));
145
146
147
        }
    }

148
149
150
151
152
153
154
    /**
     * \brief Merge descriptor to 1D.
     *
     * \param shape Layout shape.
     * \param desc Descriptor to merge.
     * \return 1D descriptor.
     */
155
156
    template <typename... ShapeDims, typename DescriptorToMerge>
    __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
157
                                                          const DescriptorToMerge& desc)
158
159
    {
        // Reverse each element in tuple
160
        const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
161
        // Generate reverted indexes (column major traverse)
162
163
164
        using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
        const auto lower_dims    = make_tuple(MergeElemsSequence::Reverse());
        const auto upper_dims    = make_tuple(Sequence<0>{});
165
        // Merge to 1d
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
        if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
        {
            return transform_tensor_descriptor(
                desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
        }
        else
        {
            // If the descriptor is known at the compilation time,
            // use `make_merge_transform_v1_carry_check` because it doesn't use
            // memcpy.
            return transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform_v1_carry_check(merge_elems)),
                lower_dims,
                upper_dims);
        }
182
183
    }

184
185
186
187
188
189
190
191
192
193
194
195
    /**
     * \brief Merge nested shape dims when corresponding index is also merged.
     * Input desc shape: 2,  2,  2, 2,  2, 2
     * Example idx:      1,      1, 1, (1, 1)
     * Example shape:    2, (2, 2), 2, (2, 2)
     * Merged shape:     2,      4, 2,  2, 2
     *
     * \param shape Layout shape.
     * \param idxs Indexes to align descriptor.
     * \param desc Descriptor to merge.
     * \return Aligned descriptor to idx.
     */
196
    template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
197
198
199
200
    __host__ __device__ constexpr static auto
    CreateMergedDescriptor(const Tuple<ShapeDims...>& shape,
                           [[maybe_unused]] const Tuple<IdxDims...>& idxs,
                           DescriptorToMerge& desc)
201
202
203
204
205
206
207
208
209
210
211
    {
        const auto transforms = generate_tuple(
            [&](auto i) {
                // Compare Idx with shape
                if constexpr(is_detected<is_tuple,
                                         tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                             !is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
                {
                    // If shape element is tuple and idx element is Number, then merge
                    // Unroll and reverse tuple to traverse column-major
                    const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
212
213
214
215
216
217
218
219
220
221
222
                    if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
                    {
                        return make_merge_transform(merge_elems);
                    }
                    else
                    {
                        // If the descriptor is known at the compilation time,
                        // use `make_merge_transform_v1_carry_check` because
                        // it doesn't use memcpy.
                        return make_merge_transform_v1_carry_check(merge_elems);
                    }
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
                }
                else
                {
                    // If shape element is integer and idx element is tuple, passed idx is wrong
                    static_assert(
                        !(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                          is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
                        "Wrong Idx for layout()");
                    // If shape element has the same type as idx element, then pass through
                    return make_pass_through_transform(shape.At(i));
                }
            },
            Number<Tuple<ShapeDims...>::Size()>{});

        const auto lower_dims =
            generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
                           Number<Tuple<ShapeDims...>::Size()>{});
        const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
                                               Number<Tuple<ShapeDims...>::Size()>{});

        return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
    }

246
    using Descriptor1dType =
247
        remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
248
249
    using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;

250
    public:
251
252
253
    using LayoutShape                  = Shape;
    using LayoutUnrolledDescriptorType = UnrolledDescriptorType;

254
255
256
257
258
259
260
261
    /**
     * \brief Transform descriptor to align to passed indexes.
     *
     * \param shape Layout shape.
     * \param idxs Indexes to align descriptor.
     * \param naive_descriptor Descriptor to merge.
     * \return Aligned descriptor to idx.
     */
262
    template <typename... ShapeDims, typename... IdxDims>
263
264
    __host__ __device__ constexpr static auto
    TransformDesc(const Tuple<ShapeDims...>& shape,
265
266
                  const Tuple<IdxDims...>& idxs,
                  const UnrolledDescriptorType& naive_descriptor)
267
268
269
270
    {
        if constexpr(Tuple<IdxDims...>::Size() == I1)
        {
            // 1d idx path
271
            return MakeMerge1d(shape, naive_descriptor);
272
273
274
275
276
277
278
279
280
281
        }
        else
        {
            // Merge nested shape dims
            // Example idx:   (1,      1), 1,      1
            // Example shape: (2, (2, 2)), 2, (2, 2)
            // Merged shape:  (2,      4), 2,      4
            static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
                          "Idx rank and Shape rank must be the same (except 1d).");
            // Unroll while IdxDims is nested
282
            const auto aligned_shape = AlignShapeToIdx(shape, idxs);
283
            // Transform correct form of shape
284
            return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor);
285
286
287
        }
    }

288
    using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
289
        Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;
290

291
292
    __host__ __device__ constexpr auto GetElementSpaceSize() const
    {
293
        return unrolled_descriptor_.GetElementSpaceSize();
294
    }
295

296
    __host__ __device__ Layout() = delete;
297

298
299
300
301
    /**
     * \brief Layout constructor.
     *
     * \param shape Shape for layout.
302
     * \param unnested_descriptor Descriptor
303
     */
304
    __host__ __device__ constexpr Layout(const Shape& shape,
305
306
                                         const UnrolledDescriptorType& unnested_descriptor)
        : unrolled_descriptor_(unnested_descriptor), shape_(shape)
307
308
    {
        // Construct if runtime mode
309
        if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
310
        {
311
            descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
312
            merged_nests_descriptor_ =
313
                TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
314
315
316
317
318
319
320
321
322
323
324
325
        }
    }

    /**
     * \brief Returns real offset to element in runtime.
     *
     * \tparam Idxs Tuple of indexes.
     * \return Calculated offset.
     */
    template <typename Idxs>
    __host__ __device__ constexpr index_t operator()() const
    {
326
        static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
327
                      "Compiletime operator used on runtime layout.");
328
        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
329
330
331
332
333
334
335
336
337
338
339
340
341
        using UnrolledIdx     = decltype(UnrollNestedTuple(Idxs{}));
        return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
    }

    /**
     * \brief Returns real offset to element in compile time.
     *
     * \param Idx Tuple of indexes.
     * \return Calculated offset.
     */
    template <typename... Ts>
    __host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
    {
342
343
344
345
346
347
348
349
350
351
352
353
354
        if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1)
        {
            // if 1d access
            return descriptor_1d_.CalculateOffset(Idx);
        }
        else if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size())
        {
            // if Shape::Size() access (merged nested shapes)
            return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx));
        }
        else
        {
            // Custom index, need to transform descriptor
355
            const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_);
356
357
            return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
        }
358
359
360
361
362
363
364
365
366
    }

    /**
     * \brief Length getter (product if tuple).
     *
     * \tparam IDim Tuple of indexes or index.
     * \return Calculated size.
     */
    template <index_t IDim>
367
    __host__ __device__ constexpr auto GetLength() const
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
    {
        const auto elem = shape_.At(Number<IDim>{});
        if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
        {
            const auto unrolled_element = UnrollNestedTuple(elem);
            return TupleReduce<I0.value, unrolled_element.Size()>(
                [](auto x, auto y) { return x * y; }, unrolled_element);
        }
        else
        {
            return elem;
        }
    }

    /**
     * \brief Layout size getter (product of shape).
     *
     * \return Calculated size.
     */
387
    __host__ __device__ constexpr auto GetLengths() const
388
389
390
391
392
393
394
    {
        const auto unrolled_shape = UnrollNestedTuple(shape_);
        return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
                                                            unrolled_shape);
    }

    /**
395
     * \brief Shape getter.
396
     *
397
     * \return Shape.
398
     */
399
    __host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
400

401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
    /**
     * \brief Get default lengths (tuple filled with Shape length elements).
     *
     * \return Default lengths.
     */
    __host__ __device__ constexpr auto GetDefaultLengthsTuple() const
    {
        return generate_tuple([&](auto i) { return GetLength<i>(); }, Number<Shape::Size()>{});
    }

    /**
     * \brief Get default start idx (tuple filled with 0s of the same size as Shape).
     *
     * \return Default start idx.
     */
    __host__ __device__ constexpr auto GetDefaultStartIdxs() const
    {
        return GenerateDefaultIdxsTuple(shape_);
    }

    /**
422
423
424
     * \brief Get descriptor with all nested dimensions merged.
     * Example, shape: ((2, 2), 2)
     * Descriptor lengths: (4, 2)
425
     *
426
427
428
     * \note The size of merged descriptor is the same as Layout's shape.
     *
     * \return Merged nests descriptor.
429
     */
430
431
    __host__ __device__ constexpr const MergedNestsDescriptorType&
    GetMergedNestingDescriptor() const
432
433
434
    {
        return merged_nests_descriptor_;
    }
435

436
437
438
439
440
441
442
443
444
445
446
447
    /**
     * \brief Get descriptor with all dimensions are merged (1D).
     * Example, shape: ((2, 2), 2)
     * Descriptor lengths: (8)
     *
     * \return 1D descriptor.
     */
    __host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const
    {
        return descriptor_1d_;
    }

448
449
    /**
     * \brief Get unnested descriptor (with unrolled dims)
450
451
     * Example, shape: ((2, 2), 2)
     * Descriptor lengths: (2, 2, 2)
452
     *
453
     * \return Flattened descriptor.
454
     */
455
    __host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const
456
    {
457
        return unrolled_descriptor_;
458
459
    }

460
    private:
461
462
463
    // All dimensions are unrolled
    UnrolledDescriptorType unrolled_descriptor_;
    // 1D descriptor
464
    Descriptor1dType descriptor_1d_;
465
    // All nesting are merged
466
    MergedNestsDescriptorType merged_nests_descriptor_;
467
468
469
470
    // Example, shape: ((2, 2), 2)
    // UnrolledDescriptorType lengths: (2, 2, 2)
    // Descriptor1dType lengths: (8)
    // MergedNestsDescriptorType lengths: (4, 2)
471
    const Shape shape_;
472
473
};

} // namespace wrapper
} // namespace ck