// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/wrapper/utils/layout_utils.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Layout wrapper that performs the tensor descriptor logic.
 *
 * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
 *         (dynamic layout). It is possible to pass nested shapes
 *         (e.g. ((4, 2), 2)), nested dimensions are merged.
 * \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims.
 */
template <typename Shape, typename UnnestedDescriptorType>
struct Layout
{
    private:
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    // Generate default idxs tuple (idx with all merged nested shapes):
    // one zero per top-level shape dim. Elements are runtime index_t(0)
    // for dynamic layouts and Number<0> for compile-time layouts, so the
    // resulting tuple type matches the descriptor's compile-time-ness.
    template <typename... Ts>
    __host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple<Ts...>&)
    {
        return generate_tuple(
            [&](auto) {
                if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
                {
                    // runtime layout
                    return index_t(0);
                }
                else
                {
                    // compiletime layout
                    return I0;
                }
            },
            Number<Tuple<Ts...>::Size()>{});
    }

    // Generate LowerDims in compile-time for MergeTransform using passed type.
    // If element of Tuple<Ts...> is also a tuple, then merge (generate the
    // sequence of all its unrolled dims). If the element is a scalar, then
    // pass through (sequence with one element). Sequences are reversed to
    // traverse column-major.
    template <typename Idx, typename... Ts>
    __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
    {
        if constexpr(Idx::value == 0)
        {
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                // Return Sequence for the first tuple
                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence =
                    typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                // Return first element
                return Sequence<0>{};
            }
        }
        else
        {
            // Get previous element using recurrence (in compile-time);
            // next lower dim starts one past the previous sequence's max
            // (At(I0) is the max because sequences are reversed).
            using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
            const auto next_seq_val = PreviousSeqT::At(I0) + 1;
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence =
                    typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
                        type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                return Sequence<next_seq_val>{};
            }
        }
    }

    // Iterate over nested tuples in shape.
    // Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>.
    // Example idx:     (1,      1), 1,      1
    // Example shape:   (2, (2, 2)), 2, (2, 2)
    // Unrolled shape:  2,  (2, 2),  2, (2, 2)
    template <typename... ShapeDims, typename... IdxDims>
    __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
                                                              const Tuple<IdxDims...>& idx)
    {
        if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
        {
            // Index unrolled to flatten, return shape
            return shape;
        }
        else
        {
            // Iterate over shape tuple elements:
            // 1. If corresponding idx element is tuple then return (will be unrolled)
            // 2. If no, pack in tuple. It will be restored during unroll.
            auto aligned_shape = generate_tuple(
                [&](auto i) {
                    if constexpr(is_detected<is_tuple,
                                             tuple_element_t<i, Tuple<IdxDims...>>>::value)
                    {
                        return shape.At(i);
                    }
                    else
                    {
                        return make_tuple(shape.At(i));
                    }
                },
                Number<Tuple<IdxDims...>::Size()>{});

            // Unroll one nesting level on both sides and process next step
            return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
                                   UnrollNestedTuple<0, 1>(idx));
        }
    }

    // Merge the whole (unrolled) shape into a single dimension.
    // Elements and lower dims are reversed so the 1d index traverses
    // the layout column-major.
    template <typename... ShapeDims, typename DescriptorToMerge>
    __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
                                                          const DescriptorToMerge& desc)
    {
        // Reverse each element in tuple
        const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
        // Generate reverted indexes (column major traverse)
        using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
        const auto lower_dims    = make_tuple(MergeElemsSequence::Reverse());
        const auto upper_dims    = make_tuple(Sequence<0>{});
        // Merge to 1d
        return transform_tensor_descriptor(
            desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
    }

    // Merge nested shape dims when corresponding index is not nested.
    // Input desc shape: 2,  2,  2, 2,  2,  2
    // Example idx:      1,      1, 1,      1
    // Example shape:    2, (2, 2), 2, (2, 2)
    // Merged shape:     2,      4, 2,      4
    template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
    __host__ __device__ constexpr static auto CreateMergedDescriptor(
        const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
    {
        const auto transforms = generate_tuple(
            [&](auto i) {
                // Compare Idx with shape
                if constexpr(is_detected<is_tuple,
                                         tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                             !is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
                {
                    // If shape element is tuple and idx element is Number, then merge
                    // Unroll and reverse tuple to traverse column-major
                    const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
                    return make_merge_transform(merge_elems);
                }
                else
                {
                    // If shape element is integer and idx element is tuple, passed idx is wrong
                    static_assert(
                        !(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                          is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
                        "Wrong Idx for layout()");
                    // If shape element has the same type as idx element, then pass through
                    return make_pass_through_transform(shape.At(i));
                }
            },
            Number<Tuple<ShapeDims...>::Size()>{});

        const auto lower_dims =
            generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
                           Number<Tuple<ShapeDims...>::Size()>{});
        const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
                                               Number<Tuple<ShapeDims...>::Size()>{});

        return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
    }

    using Descriptor1dType =
        remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnnestedDescriptorType{}))>;

    using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;

    // Dispatch on idx rank: a 1d idx merges the whole shape into one dim;
    // otherwise nested shape dims are merged wherever the corresponding
    // idx element is a scalar.
    template <typename... ShapeDims, typename... IdxDims>
    __host__ __device__ constexpr static auto
    TransformDesc(const Tuple<ShapeDims...>& shape,
                  const Tuple<IdxDims...>& idx,
                  const UnnestedDescriptorType& naive_descriptor)
    {
        if constexpr(Tuple<IdxDims...>::Size() == I1)
        {
            // 1d idx path
            return MakeMerge1d(shape, naive_descriptor);
        }
        else
        {
            // Merge nested shape dims
            // Example idx:   (1,      1), 1,      1
            // Example shape: (2, (2, 2)), 2, (2, 2)
            // Merged shape:  (2,      4), 2,      4
            static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
                          "Idx rank and Shape rank must be the same (except 1d).");
            // Unroll while IdxDims is nested
            const auto aligned_shape = AlignShapeToIdx(shape, idx);
            // Transform correct form of shape
            return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor);
        }
    }

    using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
        Shape{}, DefaultIdxsTupleType{}, UnnestedDescriptorType{}))>;

    public:
    /**
     * \brief Element space size getter (forwarded from the unnested
     *        descriptor).
     *
     * \return Element space size.
     */
    __host__ __device__ constexpr auto GetElementSpaceSize() const
    {
        return unnested_descriptor_.GetElementSpaceSize();
    }

    __host__ __device__ Layout() = delete;

    /**
     * \brief Layout constructor.
     *
     * \param shape Shape for layout.
     * \param unnested_descriptor Descriptor for the unnested (flatten)
     *        shape dims.
     */
    __host__ __device__ constexpr Layout(const Shape& shape,
                                         const UnnestedDescriptorType& unnested_descriptor)
        : shape_(shape)
    {
        // Construct cached descriptors only in runtime mode; in compile-time
        // mode the default-constructed (stateless) descriptors are sufficient.
        if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
        {
            unnested_descriptor_ = unnested_descriptor;
            descriptor_1d_       = MakeMerge1d(shape_, unnested_descriptor_);
            merged_nests_descriptor_ =
                TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_);
        }
    }

    /**
     * \brief Returns real offset to element in compile time.
     *
     * \tparam Idxs Tuple of indexes.
     * \return Calculated offset.
     */
    template <typename Idxs>
    __host__ __device__ constexpr index_t operator()() const
    {
        static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(),
                      "Compiletime operator used on runtime layout.");
        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{}));
        using UnrolledIdx     = decltype(UnrollNestedTuple(Idxs{}));
        return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
    }

    /**
     * \brief Returns real offset to element in runtime.
     *
     * \param Idx Tuple of indexes.
     * \return Calculated offset.
     */
    template <typename... Ts>
    __host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
    {
        if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1)
        {
            // if 1d access
            return descriptor_1d_.CalculateOffset(Idx);
        }
        else if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size())
        {
            // if Shape::Size() access (merged nested shapes)
            return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx));
        }
        else
        {
            // Custom index, need to transform descriptor
            const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_);
            return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
        }
    }

    /**
     * \brief Length getter (product if tuple).
     *
     * \tparam IDim Index of dimension.
     * \return Calculated size.
     */
    template <index_t IDim>
    __host__ __device__ constexpr auto GetLength() const
    {
        const auto elem = shape_.At(Number<IDim>{});
        if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
        {
            // Nested dim: length is the product of all unrolled sub-dims
            const auto unrolled_element = UnrollNestedTuple(elem);
            return TupleReduce<I0.value, unrolled_element.Size()>(
                [](auto x, auto y) { return x * y; }, unrolled_element);
        }
        else
        {
            return elem;
        }
    }

    /**
     * \brief Layout size getter (product of shape).
     *
     * \return Calculated size.
     */
    __host__ __device__ constexpr auto GetLengths() const
    {
        const auto unrolled_shape = UnrollNestedTuple(shape_);
        return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
                                                            unrolled_shape);
    }

    /**
     * \brief Shape getter.
     *
     * \return Shape.
     */
    __host__ __device__ constexpr const Shape& GetShape() const { return shape_; }

    /**
     * \brief Get default lengths (tuple filled with Shape length elements).
     *
     * \return Default lengths.
     */
    __host__ __device__ constexpr auto GetDefaultLengthsTuple() const
    {
        return generate_tuple([&](auto i) { return GetLength<i>(); }, Number<Shape::Size()>{});
    }

    /**
     * \brief Get default start idx (tuple filled with 0s of the same size as Shape).
     *
     * \return Default start idx.
     */
    __host__ __device__ constexpr auto GetDefaultStartIdxs() const
    {
        return GenerateDefaultIdxsTuple(shape_);
    }

    /**
     * \brief Get default descriptor (with the same size as Shape)
     *
     * \return Default descriptor.
     */
    __host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const
    {
        return merged_nests_descriptor_;
    }

    /**
     * \brief Get unnested descriptor (with unrolled dims)
     *
     * \return Flatten descriptor.
     */
    __host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const
    {
        return unnested_descriptor_;
    }

    private:
    // Descriptor over the fully-unrolled (flatten) shape dims.
    UnnestedDescriptorType unnested_descriptor_;
    // Cached whole-shape merge, used by the 1d operator() path.
    Descriptor1dType descriptor_1d_;
    // Cached descriptor with nested dims merged to Shape::Size() dims.
    MergedNestsDescriptorType merged_nests_descriptor_;
    const Shape shape_;
};

} // namespace wrapper
} // namespace ck