// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/wrapper/layout_utils.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Layout wrapper that performs the tensor descriptor logic.
 *
 * \tparam Shape Tuple of Number<> (for a compile-time layout) or index_t
 *         (for a dynamic layout). Nested shapes may be passed
 *         (e.g. ((4, 2), 2)); nested dimensions are merged.
 * \tparam Strides Tuple of Number<> (for a compile-time layout) or index_t
 *         (for a dynamic layout). The strides tuple must be nested if the
 *         shape tuple is nested.
 */
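// A minimal usage sketch (illustrative only; the shape, stride and index
// values below are hypothetical):
//
//     const auto shape = make_tuple(make_tuple(4, 2), 2); // nested dims
//     Layout<decltype(shape), Tuple<>> layout(shape);
//     // packed column-major strides (1, 4, 8) are deduced from the shape
//     const index_t offset = layout(make_tuple(make_tuple(1, 1), 1));
//     // offset == 1 * 1 + 1 * 4 + 1 * 8 == 13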
template <typename Shape, typename Strides>
struct Layout
{
    private:
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    // Generate packed (column-major) strides if not passed
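    // e.g. an unrolled shape (4, 2, 2) yields packed strides (1, 4, 8),
    // since stride[i] is the product of the first i unrolled shape elements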
    template <typename... Ts>
    __host__ __device__ constexpr static auto
    GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
    {
        const auto unrolled_shape = UnrollNestedTuple(shape);
        return generate_tuple(
            [&](auto i) {
                if constexpr(i.value == 0)
                {
                    return I1;
                }
                else
                {
                    return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
                                                          unrolled_shape);
                }
            },
            Number<decltype(unrolled_shape)::Size()>{});
    }

    // Generate LowerDims at compile time for the merge transform using the passed type.
    // If an element of Tuple<Ts...> is itself a tuple, then merge it (generate a sequence for the merge).
    // If the element is not a tuple, then pass it through (sequence with one element).
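    // e.g. for shape ((2, 2), 2) (illustrative):
    //   Idx 0 -> Sequence<1, 0> (two flattened dims to merge, reversed)
    //   Idx 1 -> Sequence<2>    (single dim, passed through)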
    template <typename Idx, typename... Ts>
    __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
    {
        if constexpr(Idx::value == 0)
        {
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                // Return Sequence for the first tuple
                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence =
                    typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                // Return first element
                return Sequence<0>{};
            }
        }
        else
        {
            // Get the previous element via recursion (at compile time)
            using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
            const auto next_seq_val = PreviousSeqT::At(I0) + 1;
            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
            {
                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
                using LowerDimsSequence =
                    typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
                        type;
                return LowerDimsSequence::Reverse();
            }
            else
            {
                return Sequence<next_seq_val>{};
            }
        }
    }

    // Iterate over nested tuples in shape
    // Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
    // Example idx:     (1,      1), 1,      1
    // Example shape:   (2, (2, 2)), 2, (2, 2)
    // Unrolled shape:  2,  (2, 2),  2, (2, 2)
    template <typename... ShapeDims, typename... IdxDims>
    __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
                                                              const Tuple<IdxDims...>& idx)
    {
        if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
        {
            // Index is fully flattened, return the shape
            return shape;
        }
        else
        {
            // Iterate over the shape tuple elements:
            // 1. If the corresponding idx element is a tuple, return the shape
            //    element as-is (it will be unrolled).
            // 2. If not, pack it in a tuple; it will be restored during the unroll.
            auto unrolled_shape_via_idx = generate_tuple(
                [&](auto i) {
                    if constexpr(is_detected<is_tuple,
                                             tuple_element_t<i, Tuple<IdxDims...>>>::value)
                    {
                        return shape.At(i);
                    }
                    else
                    {
                        return make_tuple(shape.At(i));
                    }
                },
                Number<Tuple<IdxDims...>::Size()>{});

            // Unroll and process next step
            return AlignShapeToIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx),
                                   UnrollNestedTuple<0, 1>(idx));
        }
    }

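    // Merge the whole (flattened) shape into a single dimension so the layout
    // can be addressed with a 1d index, traversed column-major.
    // e.g. shape (2, (2, 2)) is merged into a single dim of length 8.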
    template <typename... ShapeDims, typename DescriptorToMerge>
    __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
                                                          DescriptorToMerge& desc)
    {
        // Unroll the nested shape, then reverse it for column-major traversal
        const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
        // Generate reversed indices (column-major traversal)
        using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
        const auto lower_dims    = make_tuple(MergeElemsSequence::Reverse());
        const auto upper_dims    = make_tuple(Sequence<0>{});
        // Merge to 1d
        return transform_tensor_descriptor(
            desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
    }

    // Merge nested shape dims
    // Input desc shape: 2,  2,  2, 2,  2,  2
    // Example idx:      1,      1, 1,      1
    // Example shape:    2, (2, 2), 2, (2, 2)
    // Merged shape:     2,      4, 2,      4
    template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
    __host__ __device__ constexpr static auto CreateMergedDescriptor(
        const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
    {
        const auto transforms = generate_tuple(
            [&](auto i) {
                // Compare Idx with shape
                if constexpr(is_detected<is_tuple,
                                         tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                             !is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
                {
                    // If shape element is tuple and idx element is Number, then merge
                    // Unroll and reverse tuple to traverse column-major
                    const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
                    return make_merge_transform(merge_elems);
                }
                else
                {
                    // If the shape element is an integer and the idx element is a tuple, the passed idx is wrong
                    static_assert(
                        !(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
                          is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
                        "Wrong Idx for layout()");
                    // If shape element has the same type as idx element, then pass through
                    return make_pass_through_transform(shape.At(i));
                }
            },
            Number<Tuple<ShapeDims...>::Size()>{});

        const auto lower_dims =
            generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
                           Number<Tuple<ShapeDims...>::Size()>{});
        const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
                                               Number<Tuple<ShapeDims...>::Size()>{});

        return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
    }

    template <typename... ShapeDims, typename... IdxDims>
    __host__ __device__ constexpr auto TransformDesc(const Tuple<ShapeDims...>& shape,
                                                     const Tuple<IdxDims...>& idx) const
    {
        if constexpr(Tuple<IdxDims...>::Size() == I1)
        {
            // 1d idx path
            return MakeMerge1d(shape, descriptor_);
        }
        else
        {
            // Merge nested shape dims
            // Example idx:   (1,      1), 1,      1
            // Example shape: (2, (2, 2)), 2, (2, 2)
            // Merged shape:  (2,      4), 2,      4
            static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
                          "Idx rank and Shape rank must be the same (except 1d).");
            // Unroll while IdxDims is nested
            const auto unrolled_shape_via_idx = AlignShapeToIdx(shape, idx);
            // Transform the aligned shape
            return CreateMergedDescriptor(
                unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_);
        }
    }

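    // Create the underlying naive descriptor from the (possibly nested) shape
    // and strides; both are flattened first so that their ranks match.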
    template <typename LayoutShape, typename LayoutStrides>
    __host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape,
                                                        const LayoutStrides& strides)
    {
        const auto unrolled_shape   = UnrollNestedTuple(shape);
        const auto unrolled_strides = UnrollNestedTuple(strides);
        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
                      "Sizes of the shape and strides are not consistent.");
        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
    }

    public:
    // If strides were not passed, deduce them with GenerateColumnMajorPackedStrides
    using DeducedStrides =
        std::conditional_t<is_same_v<Strides, Tuple<>>,
                           remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
                           Strides>;
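    // e.g. for Shape ((4, 2), 2) and Strides = Tuple<>, the deduced strides
    // are the flat column-major packed tuple (1, 4, 8) (illustrative values)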
    using NaiveDescriptorType =
        remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, DeducedStrides{}))>;

    /**
     * \brief Layout constructor.
     *
     * \param shape Shape for the layout.
     * \param strides Strides for the layout (optional if the tensor is packed).
     */
    __host__ __device__ Layout() = delete;
    __host__ __device__ Layout(const Shape& shape, const Strides& strides) : descriptor_{}
    {
        // Construct if runtime mode
        if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
        {
            shape_      = shape;
            strides_    = strides;
            descriptor_ = MakeNaiveDescriptor(shape_, strides_);
        }
    }

    __host__ __device__ Layout(const Shape& shape) : descriptor_{}
    {
        if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
        {
            shape_      = shape;
            strides_    = GenerateColumnMajorPackedStrides(shape_);
            descriptor_ = MakeNaiveDescriptor(shape_, strides_);
        }
    }

    /**
 * \brief Returns the offset to an element, computed at compile time.
     *
     * \tparam Idxs Tuple of indexes.
     * \return Calculated offset.
     */
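    // e.g. (illustrative, compile-time indices):
    //   constexpr index_t offset = layout.operator()<Tuple<Number<1>, Number<0>>>();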
    template <typename Idxs>
    __host__ __device__ constexpr index_t operator()() const
    {
        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
        using UnrolledIdx     = decltype(UnrollNestedTuple(Idxs{}));
        return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
    }

    /**
 * \brief Returns the offset to an element, computed at runtime.
     *
     * \param Idx Tuple of indexes.
     * \return Calculated offset.
     */
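    // e.g. (illustrative, runtime indices): layout(make_tuple(1, 0))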
    template <typename... Ts>
    __host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
    {
        // Static to construct transformed_desc only once
        static const auto transformed_desc = TransformDesc(shape_, Idx);
        return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
    }

    /**
 * \brief Length getter (product of lengths if the dimension is nested).
     *
 * \tparam IDim Dimension index.
 * \return Calculated length.
     */
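    // e.g. for Shape ((4, 2), 2): GetLength<0>() == 8 and GetLength<1>() == 2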
    template <index_t IDim>
    __host__ __device__ constexpr index_t GetLength() const
    {
        const auto elem = shape_.At(Number<IDim>{});
        if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
        {
            const auto unrolled_element = UnrollNestedTuple(elem);
            return TupleReduce<I0.value, unrolled_element.Size()>(
                [](auto x, auto y) { return x * y; }, unrolled_element);
        }
        else
        {
            return elem;
        }
    }

    /**
     * \brief Layout size getter (product of shape).
     *
     * \return Calculated size.
     */
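    // e.g. for Shape ((4, 2), 2): GetLengths() == 16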
    __host__ __device__ constexpr index_t GetLengths() const
    {
        const auto unrolled_shape = UnrollNestedTuple(shape_);
        return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
                                                            unrolled_shape);
    }

    /**
     * \brief Shape getter.
     *
     * \return Shape.
     */
    __host__ __device__ constexpr Shape GetShape() const { return shape_; }

    /**
     * \brief Strides getter.
     *
     * \return Strides.
     */
    __host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; }

    private:
    NaiveDescriptorType descriptor_;
    Shape shape_;
    DeducedStrides strides_;
};

} // namespace wrapper
} // namespace ck