// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/wrapper/layout_utils.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Layout wrapper that performs the tensor descriptor logic.
 *
 * \tparam Shape Tuple of Number<> (for a compile-time layout) or index_t
 *         (for a dynamic layout). Nested shapes may be passed
 *         (e.g. ((4, 2), 2)); nested dimensions are merged.
 * \tparam Strides Tuple of Number<> (for a compile-time layout) or index_t
 *         (for a dynamic layout). The strides tuple must be nested if the
 *         shape tuple is nested.
 */
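// A minimal usage sketch (illustrative only, assuming the Tuple and make_tuple
// utilities pulled in via layout_utils.hpp): a dynamic 4x2 layout with strides
// deduced as packed column-major (1, 4).
//
//   const auto layout = Layout<Tuple<index_t, index_t>, Tuple<>>(make_tuple(4, 2));
//   const index_t offset = layout(make_tuple(1, 1)); // 1 * 1 + 1 * 4 = 5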
template <typename Shape, typename Strides>
struct Layout
{
    private:
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    // Generate packed (column-major) strides if not passed
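    // e.g. unrolled shape (2, 3, 4) -> strides (1, 2, 6); a nested shape
    // such as ((2, 2), 3) is unrolled first, giving strides (1, 2, 4)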
    template <typename... Ts>
    __host__ __device__ constexpr static auto
    GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
    {
        const auto unrolled_shape = UnrollNestedTuple(shape);
        return generate_tuple(
            [&](auto i) {
                if constexpr(i.value == 0)
                {
                    return I1;
                }
                else
                {
                    return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
                                                          unrolled_shape);
                }
            },
            Number<decltype(unrolled_shape)::Size()>{});
    }

    // Generate LowerDims at compile time for the MergeTransform using the passed index type
    // If an element of Tuple<Ts...> is itself a tuple, merge it (generate a sequence for the merge)
    // If the element is a scalar, pass it through (sequence with one element)
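    // e.g. for shape ((2, 2), 2, (2, 2)):
    //   GenerateLowerDim<Number<0>> -> Sequence<1, 0>
    //   GenerateLowerDim<Number<1>> -> Sequence<2>
    //   GenerateLowerDim<Number<2>> -> Sequence<4, 3>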
    template <typename Idx, typename... Ts>
    __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
    {
        if constexpr(Idx::value == 0)
        {
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                // Return Sequence for the first tuple
                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence =
                    typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                // Return first element
                return Sequence<0>{};
            }
        }
        else
        {
            // Get the previous element via recursion (at compile time)
            using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
            const auto next_seq_val = PreviousSeqT::At(I0) + 1;
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence =
                    typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
                        type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                return Sequence<next_seq_val>{};
            }
        }
    }

    // Iterate over nested tuples in shape
    // Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
    // Example idx:     (1,      1), 1,      1
    // Example shape:   (2, (2, 2)), 2, (2, 2)
    // Unrolled shape:  2,  (2, 2),  2, (2, 2)
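    // Unrolled idx:   1,   1,       1,  1      (flat, so the recursion stops)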
    template <typename... ShapeDims, typename... IdxDims>
    __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
                                                              const Tuple<IdxDims...>& idx)
    {
        if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
        {
            // Index unrolled to flatten, return shape
            return shape;
        }
        else
        {
            // Iterate over shape tuple elements:
            // 1. If the corresponding idx element is a tuple, return the shape
            //    element as-is (it will be unrolled).
            // 2. If not, wrap it in a tuple so it is restored during the unroll.
            auto aligned_shape = generate_tuple(
                [&](auto i) {
                    if constexpr(is_detected<is_tuple,
                                             tuple_element_t<i, Tuple<IdxDims...>>>::value)
                    {
                        return shape.At(i);
                    }
                    else
                    {
                        return make_tuple(shape.At(i));
                    }
                },
                Number<Tuple<IdxDims...>::Size()>{});

            // Unroll and process next step
            return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
                                   UnrollNestedTuple<0, 1>(idx));
        }
    }

    template <typename... ShapeDims, typename DescriptorToMerge>
    __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
                                                          DescriptorToMerge& desc)
    {
        // Unroll the nested shape and reverse the element order (column-major traversal)
        const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
        // Generate reversed indices for the merge (column-major traversal)
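        // e.g. for shape (2, 3): merge_elems = (3, 2) and lower_dims holds
        // Sequence<1, 0>, yielding a 1d descriptor of length 6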
        using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
        const auto lower_dims    = make_tuple(MergeElemsSequence::Reverse());
        const auto upper_dims    = make_tuple(Sequence<0>{});
        // Merge to 1d
        return transform_tensor_descriptor(
            desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
    }

    // Merge nested shape dims whose corresponding idx element is a single index.
    // Input desc shape: 2,  2,  2, 2,  2,  2
    // Example idx:      1,      1, 1,      1
    // Example shape:    2, (2, 2), 2, (2, 2)
    // Merged shape:     2,      4, 2,      4
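    // For this example the generated transforms are pass-through(2), merge(2, 2),
    // pass-through(2), merge(2, 2), with lower dims Sequence<0>, Sequence<2, 1>,
    // Sequence<3>, Sequence<5, 4> and upper dims Sequence<0>...Sequence<3>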
    template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
    __host__ __device__ constexpr static auto CreateMergedDescriptor(
        const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
    {
        const auto transforms = generate_tuple(
            [&](auto i) {
                // Compare Idx with shape
                if constexpr(is_detected<is_tuple,
                                         tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                             !is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
                {
                    // If shape element is tuple and idx element is Number, then merge
                    // Unroll and reverse tuple to traverse column-major
                    const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
                    return make_merge_transform(merge_elems);
                }
                else
                {
                    // If shape element is integer and idx element is tuple, passed idx is wrong
                    static_assert(
                        !(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                          is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
                        "Wrong Idx for layout()");
                    // If both the shape element and the idx element are scalars, pass through
                    return make_pass_through_transform(shape.At(i));
                }
            },
            Number<Tuple<ShapeDims...>::Size()>{});

        const auto lower_dims =
            generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
                           Number<Tuple<ShapeDims...>::Size()>{});
        const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
                                               Number<Tuple<ShapeDims...>::Size()>{});

        return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
    }

    template <typename... ShapeDims, typename... IdxDims>
    __host__ __device__ constexpr auto TransformDesc(const Tuple<ShapeDims...>& shape,
                                                     const Tuple<IdxDims...>& idx) const
    {
        if constexpr(Tuple<IdxDims...>::Size() == I1)
        {
            // 1d idx path
            return MakeMerge1d(shape, descriptor_);
        }
        else
        {
            // Merge nested shape dims
            // Example idx:   (1,      1), 1,      1
            // Example shape: (2, (2, 2)), 2, (2, 2)
            // Merged shape:  (2,      4), 2,      4
            static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
                          "Idx rank and Shape rank must be the same (except 1d).");
            // Unroll while IdxDims is nested
            const auto aligned_shape = AlignShapeToIdx(shape, idx);
            // Create the merged descriptor from the aligned shape
            return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_);
        }
    }

    template <typename LayoutShape, typename LayoutStrides>
    __host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape,
                                                        const LayoutStrides& strides)
    {
        const auto unrolled_shape   = UnrollNestedTuple(shape);
        const auto unrolled_strides = UnrollNestedTuple(strides);
        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
                      "Size of strides and shape are not consistent.");
        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
    }

    public:
    // If strides are not passed, they are deduced via `GenerateColumnMajorPackedStrides`.
    using DeducedStrides =
        std::conditional_t<is_same_v<Strides, Tuple<>>,
                           remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
                           Strides>;
    using NaiveDescriptorType =
        remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, DeducedStrides{}))>;

    /**
     * \brief Layout constructor.
     *
     * \param shape Shape for layout.
     * \param strides Strides for layout (optional if tensor is packed).
     */
    __host__ __device__ Layout() = delete;
    __host__ __device__ Layout(const Shape& shape, const Strides& strides) : descriptor_{}
    {
        // Construct only in runtime mode
        if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
        {
            shape_      = shape;
            strides_    = strides;
            descriptor_ = MakeNaiveDescriptor(shape_, strides_);
        }
    }

    __host__ __device__ Layout(const Shape& shape) : descriptor_{}
    {
        if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
        {
            shape_      = shape;
            strides_    = GenerateColumnMajorPackedStrides(shape_);
            descriptor_ = MakeNaiveDescriptor(shape_, strides_);
        }
    }

    /**
     * \brief Returns the offset to an element, calculated at compile time.
     *
     * \tparam Idxs Tuple of indexes.
     * \return Calculated offset.
     */
    template <typename Idxs>
    __host__ __device__ constexpr index_t operator()() const
    {
        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
        using UnrolledIdx     = decltype(UnrollNestedTuple(Idxs{}));
        return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
    }
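    // Compile-time usage sketch (a hypothetical fully static layout):
    //   using StaticShape = Tuple<Number<2>, Number<4>>;
    //   Layout<StaticShape, Tuple<>> layout(StaticShape{});
    //   const index_t off = layout.operator()<Tuple<Number<1>, Number<2>>>();
    //   // off == 1 * 1 + 2 * 2 == 5 with deduced packed strides (1, 2)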

    /**
     * \brief Returns the offset to an element, calculated at runtime.
     *
     * \param Idx Tuple of indexes.
     * \return Calculated offset.
     */
    template <typename... Ts>
    __host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
    {
        // Static to construct transformed_desc only once
        static const auto transformed_desc = TransformDesc(shape_, Idx);
        return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
    }

    /**
     * \brief Length getter (product of nested dims if the IDim-th shape element is a tuple).
     *
     * \tparam IDim Dimension index.
     * \return Calculated length.
     */
    template <index_t IDim>
    __host__ __device__ constexpr index_t GetLength() const
    {
        const auto elem = shape_.At(Number<IDim>{});
        if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
        {
            const auto unrolled_element = UnrollNestedTuple(elem);
            return TupleReduce<I0.value, unrolled_element.Size()>(
                [](auto x, auto y) { return x * y; }, unrolled_element);
        }
        else
        {
            return elem;
        }
    }

    /**
     * \brief Layout size getter (product of shape).
     *
     * \return Calculated size.
     */
    __host__ __device__ constexpr index_t GetLengths() const
    {
        const auto unrolled_shape = UnrollNestedTuple(shape_);
        return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
                                                            unrolled_shape);
    }
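    // e.g. for shape ((2, 2), 4): GetLength<0>() == 4, GetLength<1>() == 4,
    // and GetLengths() == 16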

    /**
     * \brief Shape getter.
     *
     * \return Shape.
     */
    __host__ __device__ constexpr Shape GetShape() const { return shape_; }

    /**
     * \brief Strides getter.
     *
     * \return Strides.
     */
    __host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; }

    private:
    NaiveDescriptorType descriptor_;
    Shape shape_;
    DeducedStrides strides_;
};

} // namespace wrapper
} // namespace ck