// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "utils/tensor_utils.hpp"
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"

namespace ck {
namespace wrapper {

namespace detail {
namespace {
/**
 * \brief Check if a Tuple contains a Slice object.
 *
 * \return True if the tuple contains a Slice object.
 */
template <typename T>
__host__ __device__ constexpr bool HasSlice(T&&)
{
    return is_detected<is_slice, T>::value;
}
template <typename... Ts>
__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
{
    return (HasSlice(Ts{}) || ...);
}

/**
 * \brief Calculate the new shape after slicing the parent shape.
 *
 * \param idxs Tuple of indices defining slice ranges.
 * \param shape Shape which will be sliced.
 * \return New tensor shape.
 */
template <typename... Ts, typename SlicedShape>
__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs,
                                                  const SlicedShape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto new_shape = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
                {
                    // if the tuple does not contain any slice, the dimension can be removed
                    return Tuple<>{};
                }
                else
                {
                    // if the element is a tuple, recurse
                    return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
                }
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                // calculate new dimension
                const auto& dim = size(shape.At(num_i));
                const auto val  = idxs.At(num_i).range(dim);
                return make_tuple(val);
            }
            else
            {
                // plain index: the dimension is removed
                return Tuple<>{};
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple<0, 1>(new_shape);
}
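
// Illustrative example (a sketch, not part of the library): for a shape (6, 4) and
// the index tuple (slice(0, 3), 2), the slice keeps elements [0, 3) of dimension 0
// while the plain index 2 removes dimension 1, so the resulting shape is (3).
// Nested tuples are handled recursively and removed dimensions are dropped by
// UnrollNestedTuple.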

/**
 * \brief Generate a Freeze transform for each dimension of the nested shape.
 *
 * \param idx Scalar index to be decomposed across the unrolled dimensions.
 * \param shape Shape which will be frozen.
 * \return Generated freeze transforms.
 */
template <typename T, typename Shape>
__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return generate_tuple(
        [&](auto i) {
            // dimension offset from idx
            const auto dim     = unrolled_shape.At(Number<i>{});
            const auto dim_idx = idx % dim;
            idx /= dim;
            return make_freeze_transform(dim_idx);
        },
        Number<decltype(unrolled_shape)::Size()>{});
}
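
// Illustrative example (a sketch): for a nested shape (2, (3, 4)) the unrolled shape
// is (2, 3, 4). A scalar index idx = 17 is decomposed dimension by dimension:
// 17 % 2 = 1 (idx becomes 8), 8 % 3 = 2 (idx becomes 2), 2 % 4 = 2, which produces
// freeze transforms at indices (1, 2, 2).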

/**
 * \brief Generate transforms for the sliced tensor.
 *
 * \param idx Tuple of start indices for slice.
 * \param shape Shape which will be sliced.
 * \return Generated transforms.
 */
template <typename... Ts, typename Shape>
__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx,
                                                           const Shape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto transforms = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {

                const auto from  = idx.At(num_i).from_;
                const auto dim   = size<num_i>(shape);
                const auto range = idx.At(num_i).range(dim);
                return make_slice_transform(range, from, from + range);
            }
            else
            {
                // plain index: freeze the dimension(s) at the given value
                return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple(transforms);
}
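
// Illustrative example (a sketch): for a shape (6, 4) and the index tuple
// (slice(0, 3), 2), dimension 0 produces a slice transform over elements [0, 3) and
// dimension 1 produces a freeze transform at index 2; the unrolled result is the
// transform tuple used to build the sliced descriptor.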

template <index_t i, typename LowerIndex>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&)
{
    // There is no output for Freeze transform
    return Sequence<>{};
}

template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
{
    return Sequence<i>{};
}

template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
{
    return Tuple<>{};
}

template <index_t i, typename... Transforms>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& transforms)
{
    constexpr auto num_transforms = Tuple<Transforms...>::Size();
    // Deduce Sequence element for specific transform
    const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
    if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
    {
        const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
    else
    {
        // Increase i if current_elem is Slice transform
        const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
}
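
// Illustrative example (a sketch): for a transform tuple (Freeze, Slice, Freeze, Slice)
// the generated upper dimensions are (Sequence<>, Sequence<0>, Sequence<>, Sequence<1>):
// freeze transforms contribute no upper dimension, while each slice transform is
// assigned the next output dimension index.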

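/**
 * \brief Generate a descriptor for the sliced tensor.
 *
 * \param idx Tuple of indices: slices, scalars or nested tuples.
 * \param shape Shape which will be sliced.
 * \param flatten_desc Flattened descriptor of the parent tensor.
 * \return Transformed descriptor for the sliced tensor.
 */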
template <typename... Ts, typename Shape, typename FlattenDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
                                                            const Shape& shape,
                                                            const FlattenDescriptor& flatten_desc)
{
    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();

    const auto transforms     = GenerateSliceTransforms(idx, shape);
    using TransformsTupleType = decltype(transforms);

    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
    const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
} // namespace
} // namespace detail

/**
 * \brief Tensor wrapper that handles static and dynamic buffer logic.
 * The tensor is based on a descriptor stored in the Layout. Additionally,
 * the tensor can be sliced or shifted using a multi-index offset.
 *
 * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
 * \tparam ElementType Element data type.
 * \tparam Shape Tensor shape (layout component).
 * \tparam UnrolledDescriptorType Flattened descriptor (layout component).
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
struct Tensor
{
    public:
    using ElementSpaceSize  = decltype(Layout<Shape, UnrolledDescriptorType>{
        Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
    using TensorElementType = ElementType;                          // DataType

    static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
    static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum::Sgpr ||
                                              BufferAddressSpace == MemoryTypeEnum::Vgpr);

    __host__ __device__ Tensor() = delete;
    __host__ __device__ constexpr Tensor(ElementType* pointer,
                                         const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
    {
        return layout_;
    }

    /**
     * \brief Get the new sliced tensor.
     *
     * \param idx Tuple of indices: slice(from,to) or scalar.
     * \return Sliced tensor.
     */
    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator[](const Tuple<Ts...>& idx)
    {
        static_assert(IsDynamicBuffer, "Register slice is not supported");
        const auto& shape = layout_.GetShape();
        auto new_shape    = detail::GetSlicedShape(idx, shape);

        const auto& flatten_desc = layout_.GetUnrolledDescriptor();
        auto new_desc            = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);
        const auto new_layout =
            Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
        // Update embed offset
        base_offset_ -= new_layout(make_tuple(Number<0>{}));
        return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
    }

    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ auto operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }
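
    // Illustrative usage (a sketch; slice() is the wrapper helper mentioned above and
    // the concrete indices are assumptions):
    //   auto tile = tensor(slice(0, 4), 1); // keep elements [0, 4) of dim 0, freeze dim 1 at 1
    //   auto val  = tile(2);                // element access on the 1-D sliced view
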

276
277
278
279
280
281
282
    /**
     * \brief Getter of the tensor's const value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_[offset];
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Calculate and apply the base offset at compile time
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_[Number<index_offset + base_offset>{}];
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ const ElementType& operator()(Idxs... idxs) const
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Getter of tensor value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_(offset);
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Apply the base offset (calculated at compile time)
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_(Number<index_offset + base_offset>{});
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ ElementType& operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Get descriptor with all nested dimensions merged.
     *
     * \return Descriptor with nested dimensions merged.
     */
    __host__ __device__ constexpr auto GetMergedNestingDescriptor()
    {
        return layout_.GetMergedNestingDescriptor();
    }

    /**
     * \brief Get pointer to the data.
     *
     * \return Pointer.
     */
    __host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }

    __host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
    __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }

    /**
     * \brief Get multi index offset to the data.
     *
     * \return Multi index offset.
     */
    __host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }

    /**
     * \brief Apply multi index offset on the tensor.
     *
     * \param multi_idx_offset Multi index offset.
     */
    template <typename MultiIdxOffsets>
    __host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
    {
        multi_idx_offset_ = multi_idx_offset;
        base_offset_ += layout_(multi_idx_offset);
    }
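
    // Illustrative usage (a sketch; make_multi_index and the offset values are
    // assumptions): shift the tensor so that subsequent indexing is relative to a
    // tile origin, e.g.
    //   tensor.SetMultiIdxOffset(make_multi_index(2, 0));
    //   // tensor(0, 0) now addresses the element at (2, 0) of the original data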

    private:
    using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
                                            ElementType,
                                            ElementSpaceSize,
                                            true /*InvalidElementUseNumericalZeroValue*/>;
    using StaticBufferType  = StaticBuffer<BufferAddressSpace,
                                          ElementType,
                                          size(Shape{}),
                                          true /*InvalidElementUseNumericalZeroValue*/>;
    // For register memory (SGPR/VGPR) use a static buffer, otherwise use a dynamic buffer
    using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;

    const Layout<Shape, UnrolledDescriptorType> layout_;
    Buffer buffer_;
    // We use multi_idx_offset_ to enable the creation of a descriptor at
    // compile time for partitions or tiles when the tile shape and thread
    // layout are known at compile time (the same descriptor can be used for
    // each thread). Additionally, the copy between static and dynamic buffers
    // requires a descriptor known at compile time, so we shift the data using
    // such a multi_idx_offset_.
    MultiIndex<Shape::Size()> multi_idx_offset_;
    // The base offset and the multi index offset correspond to exactly the
    // same element in the tensor (and in physical memory). The multi index
    // offset is a multi-dimensional index, whereas the base offset is
    // calculated through the tensor descriptor (thus all of its transforms)
    // and is linear (1D). We store base_offset_ to avoid recalculating it.
    index_t base_offset_;
};
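
// Illustrative end-to-end usage (a sketch; make_layout and slice are wrapper helpers
// declared elsewhere, and the shape, strides and pointer here are assumptions):
//   const auto layout = make_layout(make_tuple(6, 4), make_tuple(4, 1));
//   auto tensor       = make_tensor<MemoryTypeEnum::Global>(ptr, layout);
//   auto tile         = tensor(slice(0, 4), slice(0, 2)); // 4 x 2 view of the tensor
//   tile(0, 0)        = 42;                               // element write through the view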

} // namespace wrapper
} // namespace ck