// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "utils/tensor_utils.hpp"
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"

namespace ck {
namespace wrapper {

namespace {
namespace detail {
/**
 * \brief Check if a Tuple contains a Slice object.
 *
 * \return True if the tuple contains a Slice object.
 */
template <typename T>
__host__ __device__ constexpr bool HasSlice(T&&)
{
    return is_detected<is_slice, T>::value;
}
template <typename... Ts>
__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
{
    return (HasSlice(Ts{}) || ...);
}
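
// Illustrative examples (hypothetical index tuples, assuming a slice(from, to)
// helper from tensor_utils.hpp): HasSlice reports whether any element, possibly
// nested, is a Slice object.
//   HasSlice(make_tuple(slice(0, 2), Number<1>{}))  // -> true  (contains a Slice)
//   HasSlice(make_tuple(Number<0>{}, Number<1>{}))  // -> false (plain indices only)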

/**
 * \brief Calculate new shape after slice from parent shape.
 *
 * \param idxs Tuple of indices defining slice ranges.
 * \param shape Shape which will be sliced.
 * \return New tensor shape.
 */
template <typename... Ts, typename SlicedShape>
__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs,
                                                  const SlicedShape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto new_shape = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
                {
                    // if the nested tuple does not contain any slice, the dimension is removed
                    return Tuple<>{};
                }
                else
                {
                    // nested tuple: recurse into it
                    return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
                }
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                // calculate new dimension
                const auto& dim = size(shape.At(num_i));
                const auto val  = idxs.At(num_i).range(dim);
                return make_tuple(val);
            }
            else
            {
                // plain index: the dimension is removed
                return Tuple<>{};
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple<0, 1>(new_shape);
}
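
// Illustrative sketch (hypothetical shape and indices; slice(from, to) is assumed
// to cover [from, to)): a scalar index removes its dimension, a slice keeps it
// with the sliced length.
//   GetSlicedShape(make_tuple(slice(2, 6), 3), make_tuple(8, 8))
//   // -> shape with a single dimension of length 4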

/**
 * \brief Generate Freeze for each of nested shape.
 *
 * \param idx Tuple of start indices for slice.
 * \param shape Shape which will be frozen.
 * \return Generated freeze transforms.
 */
template <typename T, typename Shape>
__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return generate_tuple(
        [&](auto i) {
            // extract this dimension's index from the linear idx
            const auto dim     = unrolled_shape.At(Number<i>{});
            const auto dim_idx = idx % dim;
            idx /= dim;
            return make_freeze_transform(dim_idx);
        },
        Number<decltype(unrolled_shape)::Size()>{});
}
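
// Illustrative sketch (hypothetical values): a single linear index is decomposed
// dimension by dimension (idx % dim, then idx /= dim over the unrolled shape),
// producing one Freeze transform per dimension.
//   GenerateMultipleFreeze(5, make_tuple(2, 3))
//   // -> (freeze(5 % 2 = 1), freeze((5 / 2) % 3 = 2))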

/**
 * \brief Generate transforms for slice tensor.
 *
 * \param idx Tuple of start indices for slice.
 * \param shape Shape which will be sliced.
 * \return Generated transforms.
 */
template <typename... Ts, typename Shape>
__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx,
                                                           const Shape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto transforms = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {

                const auto from  = idx.At(num_i).from_;
                const auto dim   = size<num_i>(shape);
                const auto range = idx.At(num_i).range(dim);
                return make_slice_transform(range, from, from + range);
            }
            else
            {
                // plain index: freeze the dimension(s)
                return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple(transforms);
}
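
// Illustrative sketch (hypothetical values): slices become Slice transforms,
// plain indices become Freeze transforms, and the result is flattened.
//   GenerateSliceTransforms(make_tuple(slice(0, 2), 1), make_tuple(4, 6))
//   // -> (slice_transform(range = 2, from = 0, to = 2), freeze_transform(1))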

template <index_t i, typename LowerIndex>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&)
{
    // There is no output for Freeze transform
    return Sequence<>{};
}

template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
{
    return Sequence<i>{};
}

template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
{
    return Tuple<>{};
}

template <index_t i, typename... Transforms>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& transforms)
{
    constexpr auto num_transforms = Tuple<Transforms...>::Size();
    // Deduce Sequence element for specific transform
    const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
    if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
    {
        const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
    else
    {
        // Increase i if current_elem is Slice transform
        const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
}
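
// Illustrative sketch: Freeze transforms contribute no upper dimension, while each
// Slice transform gets the next consecutive upper-dimension id. For a transforms
// tuple (Slice, Freeze, Slice) the generated upper dims would be
// (Sequence<0>, Sequence<>, Sequence<1>).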

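/**
 * \brief Create the descriptor for a sliced tensor from the flattened descriptor.
 *
 * \param idx Tuple of indices: slice(from, to) objects or scalars.
 * \param shape Shape which will be sliced.
 * \param flatten_desc Flattened descriptor of the parent tensor.
 * \return Transformed (sliced) descriptor.
 */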
template <typename... Ts, typename Shape, typename FlattenDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
                                                            const Shape& shape,
                                                            const FlattenDescriptor& flatten_desc)
{
    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();

    const auto transforms     = GenerateSliceTransforms(idx, shape);
    using TransformsTupleType = decltype(transforms);

    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
    const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
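
// Illustrative flow (hypothetical values): for idx = (slice(0, 2), 1) and shape (4, 6),
// the transforms are (Slice, Freeze), lower_dims are (Sequence<0>, Sequence<1>) and
// upper_dims are (Sequence<0>, Sequence<>), so the resulting descriptor exposes a
// single dimension of length 2.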
} // namespace detail
} // namespace

/**
 * \brief Tensor wrapper that handles static and dynamic buffer logic.
 * The tensor is based on a descriptor stored in the Layout. Additionally,
 * the tensor can be sliced or shifted using a multi-index offset.
 *
 * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
 * \tparam ElementType Element data type.
 * \tparam Shape Tensor shape (layout component).
 * \tparam UnrolledDescriptorType Flattened descriptor (layout component).
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
struct Tensor
{
    public:
    using ElementSpaceSize  = decltype(Layout<Shape, UnrolledDescriptorType>{
        Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
    using TensorElementType = std::conditional_t<
        is_scalar_type<ElementType>::value,
        ElementType,
        typename scalar_type<std::remove_const_t<ElementType>>::type>; // DataType

    static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
    static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
                                              BufferAddressSpace == MemoryTypeEnum ::Vgpr);

    __host__ __device__ Tensor() = delete;
    __host__ __device__ constexpr Tensor(ElementType* pointer,
                                         const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
    {
        return layout_;
    }

    /**
     * \brief Get the new sliced tensor.
     *
     * \param idx Tuple of indices: slice(from,to) or scalar.
     * \return Sliced tensor.
     */
    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator[](const Tuple<Ts...>& idx)
    {
        static_assert(IsDynamicBuffer, "Register slice is not supported");
        const auto& shape = layout_.GetShape();
        auto new_shape    = detail::GetSlicedShape(idx, shape);

        const auto& flatten_desc = layout_.GetUnrolledDescriptor();
        auto new_desc            = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);
        const auto new_layout =
            Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
        // Update embed offset
        base_offset_ -= new_layout(make_tuple(Number<0>{}));
        return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
    }
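
    // Illustrative usage (hypothetical tensor and slice(from, to) helper):
    //   auto column = tensor(make_tuple(slice(0, 4), 1));
    //   // -> 1D view of 4 elements taken from the second column; the original
    //   //    buffer is reused, only the layout/descriptor changes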

    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ auto operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Getter of the tensor's const value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_[offset];
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Calculate and apply the base offset at compile time
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_[Number<index_offset + base_offset>{}];
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Getter of tensor value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_(offset);
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Apply embed offset (calculated at compile time)
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_(Number<index_offset + base_offset>{});
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }
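
    // Illustrative usage (hypothetical tensor): both call styles address the same
    // element; the stored offsets are applied on top of the layout computation.
    //   tensor(1, 2) = 42;
    //   auto val = tensor(make_tuple(1, 2)); // val == 42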

    /**
     * \brief Get descriptor with all nested dimensions merged.
     *
     * \return Descriptor with merged nested dimensions.
     */
    __host__ __device__ constexpr auto GetMergedNestingDescriptor()
    {
        return layout_.GetMergedNestingDescriptor();
    }

    /**
     * \brief Get pointer to the data.
     *
     * \return Pointer.
     */
    __host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }

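    /**
     * \brief Get the underlying buffer object.
     *
     * \return Buffer reference.
     */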
    __host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
    __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }

    /**
     * \brief Get multi index offset to the data.
     *
     * \return Multi index offset.
     */
    __host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }

    /**
     * \brief Apply multi index offset on the tensor.
     *
     * \param multi_idx_offset Multi index offset.
     */
    template <typename MultiIdxOffsets>
    __host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
    {
        multi_idx_offset_ = multi_idx_offset;
        base_offset_ += layout_(multi_idx_offset);
    }
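
    // Illustrative usage (hypothetical offsets, assuming the make_multi_index helper):
    // after setting a multi index offset, subsequent accesses are shifted by it via
    // the precomputed base_offset_.
    //   tensor.SetMultiIdxOffset(make_multi_index(2, 0));
    //   auto val = tensor(0, 1); // for a simple strided layout this is element (2, 1)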

    private:
    using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
                                            ElementType,
                                            ElementSpaceSize,
                                            true /*InvalidElementUseNumericalZeroValue*/>;
    using StaticBufferType  = std::conditional_t<
        is_scalar_type<ElementType>::value,
        StaticBuffer<BufferAddressSpace,
                     ElementType,
                     size(Shape{}),
                     true /*InvalidElementUseNumericalZeroValue*/>,
        StaticBufferTupleOfVector<BufferAddressSpace,
                                  TensorElementType,
                                  size(Shape{}) /
                                      scalar_type<std::remove_const_t<ElementType>>::vector_size,
                                  scalar_type<std::remove_const_t<ElementType>>::vector_size,
                                  true /*InvalidElementUseNumericalZeroValue*/>>;
    // For registers use a static buffer, otherwise use a dynamic buffer
    using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;

    const Layout<Shape, UnrolledDescriptorType> layout_;
    Buffer buffer_;
    // We use multi_idx_offset_ to enable the creation of a descriptor at
    // compile time for partitions or tiles when the tile shape and thread
    // layout are known at compile time (the same descriptor can then be used
    // for each thread). Additionally, the copy between static and dynamic
    // buffers requires a descriptor known at compile time, so we can shift
    // data using such a multi_idx_offset_.
    MultiIndex<Shape::Size()> multi_idx_offset_;
    // Base offset and multi index offset correspond to exactly the same
    // element in the tensor (and in physical memory). The multi index offset
    // is a multi-dimensional index, while the base offset is calculated using
    // the tensor descriptor (and thus all its transforms) and is linear (1D).
    // We store base_offset_ to avoid recalculating it multiple times.
    index_t base_offset_;
};
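
// Illustrative end-to-end sketch (hypothetical shapes; assumes a make_layout helper
// from layout_utils.hpp and the make_tensor factory used above):
//   const auto layout = make_layout(make_tuple(Number<4>{}, Number<8>{}),
//                                   make_tuple(Number<8>{}, Number<1>{}));
//   auto tensor = make_tensor<MemoryTypeEnum::Global>(ptr, layout);
//   auto value  = tensor(1, 2);                        // element access
//   auto sliced = tensor(make_tuple(slice(0, 2), 3));  // sliced view (same buffer)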

} // namespace wrapper
} // namespace ck