// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/wrapper/utils/layout_utils.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Layout wrapper that performs the tensor descriptor logic.
 *
 * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
 *         (dynamic layout). It is possible to pass nested shapes
 *         (e.g. ((4, 2), 2)); nested dimensions are merged.
 * \tparam Strides Tuple of Number<> (for compile-time layout) or index_t
 *         (dynamic layout). Stride tuple should be nested if shape tuple is
 *         nested.
 */
template <typename Shape, typename Strides>
struct Layout
{
    private:
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    // Generate default idxs tuple (idx with all merged nested shapes)
    template <typename... Ts>
    __host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple<Ts...>&)
    {
        return generate_tuple(
            [&](auto) {
                if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
                {
                    // runtime layout
                    return index_t(0);
                }
                else
                {
                    // compile-time layout
                    return I0;
                }
            },
            Number<Tuple<Ts...>::Size()>{});
    }
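
    // Example (illustrative): for Shape = ((4, 2), 2) the generated default idxs
    // tuple has one zero per top-level shape element, i.e. (0, 0) for a runtime
    // layout or (Number<0>{}, Number<0>{}) for a compile-time one.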

    // Generate packed (column-major) strides if not passed
    template <typename... Ts>
    __host__ __device__ constexpr static auto
    GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
    {
        const auto unrolled_shape = UnrollNestedTuple(shape);
        return generate_tuple(
            [&](auto i) {
                if constexpr(i.value == 0)
                {
                    return I1;
                }
                else
                {
                    return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
                                                          unrolled_shape);
                }
            },
            Number<decltype(unrolled_shape)::Size()>{});
    }
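
    // Worked example (illustrative): for an unrolled shape (2, 3, 4) the generated
    // packed column-major strides are (1, 2, 6), i.e. stride[0] = 1 and
    // stride[i] = shape[0] * ... * shape[i - 1].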

    // Generate LowerDims at compile time for the MergeTransform using the passed type.
    // If an element of Tuple<Ts...> is itself a tuple, merge it (generate a sequence
    // of its flattened dims). If it is a scalar, pass it through (one-element sequence).
    template <typename Idx, typename... Ts>
    __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
    {
        if constexpr(Idx::value == 0)
        {
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                // Return Sequence for the first tuple
                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence =
                    typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                // Return first element
                return Sequence<0>{};
            }
        }
        else
        {
            // Get the previous element via recursion (at compile time)
            using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
            const auto next_seq_val = PreviousSeqT::At(I0) + 1;
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence =
                    typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
                        type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                return Sequence<next_seq_val>{};
            }
        }
    }
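
    // Example (illustrative): for shape (2, (2, 2), 2) GenerateLowerDim yields
    // Sequence<0> for dim 0, Sequence<2, 1> (reversed merge dims) for dim 1 and
    // Sequence<3> for dim 2, matching the flattened descriptor's dimension order.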

    // Iterate over nested tuples in shape
    // Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
    // Example idx:     (1,      1), 1,      1
    // Example shape:   (2, (2, 2)), 2, (2, 2)
    // Unrolled shape:  2,  (2, 2),  2, (2, 2)
    template <typename... ShapeDims, typename... IdxDims>
    __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
                                                              const Tuple<IdxDims...>& idx)
    {
        if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
        {
            // Idx has been fully flattened; return the aligned shape
            return shape;
        }
        else
        {
            // Iterate over shape tuple elements:
            // 1. If the corresponding idx element is a tuple, return the shape
            //    element as-is (it will be unrolled in the next step).
            // 2. If not, wrap it in a tuple so the unroll restores it unchanged.
            auto aligned_shape = generate_tuple(
                [&](auto i) {
                    if constexpr(is_detected<is_tuple,
                                             tuple_element_t<i, Tuple<IdxDims...>>>::value)
                    {
                        return shape.At(i);
                    }
                    else
                    {
                        return make_tuple(shape.At(i));
                    }
                },
                Number<Tuple<IdxDims...>::Size()>{});

            // Unroll and process next step
            return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
                                   UnrollNestedTuple<0, 1>(idx));
        }
    }

    template <typename... ShapeDims, typename DescriptorToMerge>
    __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
                                                          const DescriptorToMerge& desc)
    {
        // Unroll the nested shape and reverse the element order (column-major traversal)
        const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
        // Generate reversed dimension ids (column-major traversal)
        using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
        const auto lower_dims    = make_tuple(MergeElemsSequence::Reverse());
        const auto upper_dims    = make_tuple(Sequence<0>{});
        // Merge to 1d
        return transform_tensor_descriptor(
            desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
    }
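
    // Example (illustrative): for shape (2, (2, 2)) the flattened dims (2, 2, 2)
    // are merged in reversed order into a single dimension of length 8, so a 1d
    // index traverses the layout in column-major fashion.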

    // Merge nested shape dims when the corresponding idx element is a scalar.
    // Input desc shape: 2,  2,  2, 2,  2,  2
    // Example idx:      1,      1, 1,      1
    // Example shape:    2, (2, 2), 2, (2, 2)
    // Merged shape:     2,      4, 2,      4
    template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
    __host__ __device__ constexpr static auto CreateMergedDescriptor(
        const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
    {
        const auto transforms = generate_tuple(
            [&](auto i) {
                // Compare Idx with shape
                if constexpr(is_detected<is_tuple,
                                         tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                             !is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
                {
                    // If shape element is tuple and idx element is Number, then merge
                    // Unroll and reverse tuple to traverse column-major
                    const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
                    return make_merge_transform(merge_elems);
                }
                else
                {
                    // If shape element is integer and idx element is tuple, passed idx is wrong
                    static_assert(
                        !(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                          is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
                        "Wrong Idx for layout()");
                    // Otherwise both elements are scalars, so pass the dim through
                    return make_pass_through_transform(shape.At(i));
                }
            },
            Number<Tuple<ShapeDims...>::Size()>{});

        const auto lower_dims =
            generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
                           Number<Tuple<ShapeDims...>::Size()>{});
        const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
                                               Number<Tuple<ShapeDims...>::Size()>{});

        return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
    }
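
    // Example (illustrative): for the aligned shape (2, (2, 2), 2, (2, 2)) and a
    // flat idx the generated transforms are (pass-through, merge, pass-through,
    // merge), with lower dims (Seq<0>, Seq<2, 1>, Seq<3>, Seq<5, 4>) and upper
    // dims (Seq<0>, Seq<1>, Seq<2>, Seq<3>).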

    template <typename LayoutShape, typename LayoutStrides>
    __host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
                                                          const LayoutStrides& strides)
    {
        const auto unrolled_shape   = UnrollNestedTuple(shape);
        const auto unrolled_strides = UnrollNestedTuple(strides);
        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
                      "Size of strides and shape are not consistent.");
        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
    }
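
    // Example (illustrative): for shape ((4, 2), 2) and strides ((1, 4), 8) both
    // tuples unroll to rank 3, (4, 2, 2) and (1, 4, 8), and the naive descriptor
    // is built from these flat lengths and strides.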

    // If strides are not passed (Strides is an empty tuple), deduce packed
    // column-major strides via GenerateColumnMajorPackedStrides.
    using DeducedStrides =
        std::conditional_t<is_same_v<Strides, Tuple<>>,
                           remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
                           Strides>;
    using FlattenDescriptorType =
        remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
    using Descriptor1dType =
        remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>;
    using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;

    template <typename... ShapeDims, typename... IdxDims>
    __host__ __device__ constexpr static auto
    TransformDesc(const Tuple<ShapeDims...>& shape,
                  const Tuple<IdxDims...>& idx,
                  const FlattenDescriptorType& naive_descriptor)
    {
        if constexpr(Tuple<IdxDims...>::Size() == I1)
        {
            // 1d idx path
            return MakeMerge1d(shape, naive_descriptor);
        }
        else
        {
            // Merge nested shape dims
            // Example idx:   (1,      1), 1,      1
            // Example shape: (2, (2, 2)), 2, (2, 2)
            // Merged shape:  (2,      4), 2,      4
            static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
                          "Idx rank and Shape rank must be the same (except 1d).");
            // Unroll while IdxDims is nested
            const auto aligned_shape = AlignShapeToIdx(shape, idx);
            // Transform the aligned shape into the merged descriptor
            return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor);
        }
    }

    using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
        Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>;

    public:
    __host__ __device__ constexpr auto GetElementSpaceSize() const
    {
        return flatten_descriptor_.GetElementSpaceSize();
    }

    __host__ __device__ Layout() = delete;
    /**
     * \brief Layout constructor.
     *
     * \param shape Shape for layout.
     * \param strides Strides for layout (optional if tensor is packed).
     */
    __host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides)
        : flatten_descriptor_{}, shape_(shape), strides_(strides)
    {
        // Construct the descriptors only for runtime layouts
        if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
        {
            flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
            descriptor_1d_      = MakeMerge1d(shape_, flatten_descriptor_);
            merged_nests_descriptor_ =
                TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
        }
    }

    /**
     * \brief Layout constructor (with default packed column-major strides).
     *
     * \param shape Shape for layout.
     */
    __host__ __device__ constexpr Layout(const Shape& shape)
        : flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
    {
        if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
        {
            flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
            descriptor_1d_      = MakeMerge1d(shape_, flatten_descriptor_);
            merged_nests_descriptor_ =
                TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
        }
    }
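
    // Usage sketch (illustrative, assuming runtime extents and ck's make_tuple):
    //   const auto shape   = make_tuple(make_tuple(4, 2), 2);
    //   const auto strides = make_tuple(make_tuple(1, 4), 8);
    //   Layout<decltype(shape), decltype(strides)> layout(shape, strides);
    //   // Packed column-major strides are deduced when Strides is an empty tuple:
    //   Layout<decltype(shape), Tuple<>> packed_layout(shape);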

    /**
     * \brief Returns the real offset to an element, computed at compile time.
     *
     * \tparam Idxs Tuple of indexes.
     * \return Calculated offset.
     */
    template <typename Idxs>
    __host__ __device__ constexpr index_t operator()() const
    {
        static_assert(FlattenDescriptorType::IsKnownAtCompileTime(),
                      "Compile-time operator used on a runtime layout.");
        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{}));
        using UnrolledIdx     = decltype(UnrollNestedTuple(Idxs{}));
        return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
    }
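
    // Example (illustrative): for a fully compile-time layout the offset can be
    // computed from a tuple of Number<> indexes:
    //   using Idxs = Tuple<Number<1>, Number<0>>;
    //   const index_t offset = layout.template operator()<Idxs>();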

    /**
     * \brief Returns the real offset to an element, computed at runtime.
     *
     * \param Idx Tuple of indexes.
     * \return Calculated offset.
     */
    template <typename... Ts>
    __host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
    {
        if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1)
        {
            // if 1d access
            return descriptor_1d_.CalculateOffset(Idx);
        }
        else if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size())
        {
            // if Shape::Size() access (merged nested shapes)
            return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx));
        }
        else
        {
            // Custom index, need to transform descriptor
            const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_);
            return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
        }
    }
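
    // Example (illustrative), assuming Shape = ((2, (2, 2)), 2):
    //   layout(make_tuple(4));                   // 1d (fully merged) access
    //   layout(make_tuple(1, 1));                // one index per top-level dim
    //   layout(make_tuple(make_tuple(1, 1), 1)); // custom nested index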

    /**
     * \brief Length getter (product of the nested dims if the IDim-th shape
     *        element is a tuple).
     *
     * \tparam IDim Dimension index.
     * \return Calculated length.
     */
    template <index_t IDim>
    __host__ __device__ constexpr index_t GetLength() const
    {
        const auto elem = shape_.At(Number<IDim>{});
        if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
        {
            const auto unrolled_element = UnrollNestedTuple(elem);
            return TupleReduce<I0.value, unrolled_element.Size()>(
                [](auto x, auto y) { return x * y; }, unrolled_element);
        }
        else
        {
            return elem;
        }
    }
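
    // Example (illustrative): for Shape = (2, (2, 2)), GetLength<0>() returns 2
    // and GetLength<1>() returns 4 (the product of the nested dims).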

    /**
     * \brief Layout size getter (product of shape).
     *
     * \return Calculated size.
     */
    __host__ __device__ constexpr index_t GetLengths() const
    {
        const auto unrolled_shape = UnrollNestedTuple(shape_);
        return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
                                                            unrolled_shape);
    }
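
    // Example (illustrative): for Shape = (2, (2, 2)), GetLengths() returns
    // 2 * 2 * 2 = 8.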

    /**
     * \brief Shape getter.
     *
     * \return Shape.
     */
    __host__ __device__ constexpr const Shape& GetShape() const { return shape_; }

    /**
     * \brief Strides getter.
     *
     * \return Strides.
     */
    __host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }

    /**
     * \brief Get the default lengths (a tuple with one merged length per
     *        top-level Shape element).
     *
     * \return Default lengths.
     */
    __host__ __device__ constexpr auto GetDefaultLengthsTuple() const
    {
        return generate_tuple([&](auto i) { return GetLength<i>(); }, Number<Shape::Size()>{});
    }

    /**
     * \brief Get default start idx (tuple filled with 0s of the same size as Shape).
     *
     * \return Default start idx.
     */
    __host__ __device__ constexpr auto GetDefaultStartIdxs() const
    {
        return GenerateDefaultIdxsTuple(shape_);
    }

    /**
     * \brief Get the default descriptor (with the same rank as Shape).
     *
     * \return Default descriptor.
     */
    __host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor()
    {
        return merged_nests_descriptor_;
    }

    private:
    FlattenDescriptorType flatten_descriptor_;
    Descriptor1dType descriptor_1d_;
    MergedNestsDescriptorType merged_nests_descriptor_;
    const Shape shape_;
    const DeducedStrides strides_;
};

} // namespace wrapper
} // namespace ck