// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "utils/tensor_utils.hpp"
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace {
namespace detail {
/**
 * \brief Check if a Tuple contains a Slice object.
 *
 * \return True if the tuple contains a Slice object.
 */
template <typename T>
__host__ __device__ constexpr bool HasSlice(T&&)
{
    return is_detected<is_slice, T>::value;
}
template <typename... Ts>
__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
{
    return (HasSlice(Ts{}) || ...);
}
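// Illustrative compile-time checks (a sketch only; `slice(...)` is the helper from
// the wrapper utilities included above and is assumed to produce a Slice object):
//   detail::HasSlice(make_tuple(slice(0, 2), Number<3>{})) -> true
//   detail::HasSlice(make_tuple(Number<1>{}, Number<3>{})) -> false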

/**
 * \brief Calculate the new shape of a sliced tensor from the parent shape.
 *
 * \param idxs Tuple of indices defining the slice ranges.
 * \param shape Shape which will be sliced.
 * \return New tensor shape.
 */
template <typename... Ts, typename SlicedShape>
__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs,
                                                  const SlicedShape& shape)
{
    // Pack each element into a tuple so that empty tuples can be removed after generation
    auto new_shape = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
                {
                    // If the nested tuple contains no slice, the whole dimension is removed
                    return Tuple<>{};
                }
                else
                {
                    // Nested tuple: recurse into it
                    return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
                }
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                // calculate new dimension
                const auto& dim = size(shape.At(num_i));
                const auto val  = idxs.At(num_i).range(dim);
                return make_tuple(val);
            }
            else
            {
                // Scalar index: remove the dimension
                return Tuple<>{};
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple<0, 1>(new_shape);
}
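// Worked example (illustrative only, assuming the `slice(from, to)` helper): for
// shape (4, 8) and idxs = (slice(0, 2), 3), the sliced dimension keeps its range
// of 2 while the scalar index removes the second dimension, so the new shape is (2).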

/**
 * \brief Generate a Freeze transform for each dimension of the nested shape.
 *
 * \param idx Flat index from which the per-dimension freeze positions are derived.
 * \param shape Shape which will be frozen.
 * \return Generated freeze transforms.
 */
template <typename T, typename Shape>
__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return generate_tuple(
        [&](auto i) {
            // Extract this dimension's index from the flattened idx
            const auto dim     = unrolled_shape.At(Number<i>{});
            const auto dim_idx = idx % dim;
            idx /= dim;
            return make_freeze_transform(dim_idx);
        },
        Number<decltype(unrolled_shape)::Size()>{});
}
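// Sketch of the decomposition above: for an unrolled shape (2, 4) and idx = 6, the
// first dimension yields dim_idx = 6 % 2 = 0 (idx becomes 3) and the second yields
// dim_idx = 3 % 4 = 3, so the element at coordinates (0, 3) is frozen.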

/**
 * \brief Generate the transforms for a sliced tensor.
 *
 * \param idx Tuple of starting indices for the slice.
 * \param shape Shape which will be sliced.
 * \return Generated transforms.
 */
template <typename... Ts, typename Shape>
__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx,
                                                           const Shape& shape)
{
    // Pack each element into a tuple so that empty tuples can be removed after generation
    auto transforms = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {

                const auto from  = idx.At(num_i).from_;
                const auto dim   = size<num_i>(shape);
                const auto range = idx.At(num_i).range(dim);
                return make_slice_transform(range, from, from + range);
            }
            else
            {
                // Scalar index: freeze (remove) the dimension
                return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple(transforms);
}
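// Sketch (illustrative): for shape (4, 8) and idx = (slice(1, 3), 5), dimension 0
// produces make_slice_transform(2, 1, 3) and dimension 1 produces a Freeze transform
// pinning index 5; UnrollNestedTuple then flattens the generated transform list.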

template <index_t i, typename LowerIndex>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&)
{
    // There is no output for Freeze transform
    return Sequence<>{};
}

template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
{
    return Sequence<i>{};
}

template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
{
    return Tuple<>{};
}

template <index_t i, typename... Transforms>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& transforms)
{
    constexpr auto num_transforms = Tuple<Transforms...>::Size();
    // Deduce Sequence element for specific transform
    const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
    if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
    {
        const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
    else
    {
        // Increase i if current_elem is Slice transform
        const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
}
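// Sketch: for a transform list (Slice, Freeze, Slice), GenerateUpperDims<0> yields
// (Sequence<0>, Sequence<>, Sequence<1>): a Freeze transform contributes no upper
// dimension, while each Slice transform is assigned the next consecutive upper index.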

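/**
 * \brief Create the descriptor of a sliced tensor by applying the generated
 *        slice/freeze transforms to the flattened parent descriptor.
 *
 * \param idx Tuple of indices: slice(from, to) objects or scalar values.
 * \param shape Shape which will be sliced.
 * \param flatten_desc Flattened descriptor of the parent tensor.
 * \return Transformed (sliced) descriptor.
 */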
template <typename... Ts, typename Shape, typename UnrolledDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
                                                            const Shape& shape,
                                                            const UnrolledDescriptor& flatten_desc)
{
    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();

    const auto transforms     = GenerateSliceTransforms(idx, shape);
    using TransformsTupleType = decltype(transforms);

    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
    const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
} // namespace detail
} // namespace
/// @endcond

/**
 * \brief Tensor wrapper that performs static and dynamic buffer logic.
 * The tensor is based on a descriptor stored in the Layout. Additionally,
 * the tensor can be sliced or shifted using a multi-index offset
 * (an illustrative usage sketch follows this struct definition).
 *
 * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
 * \tparam ElementType Element data type.
 * \tparam Shape Tensor shape (layout component).
 * \tparam UnrolledDescriptorType Flattened descriptor (layout component).
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
struct Tensor
{
    public:
    using ElementSpaceSize  = decltype(Layout<Shape, UnrolledDescriptorType>{
        Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
    using TensorElementType = std::conditional_t<
        is_scalar_type<ElementType>::value,
        ElementType,
        typename scalar_type<std::remove_const_t<ElementType>>::type>; // DataType

    static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
    static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
                                              BufferAddressSpace == MemoryTypeEnum ::Vgpr);

    __host__ __device__ Tensor() = delete;
    __host__ __device__ constexpr Tensor(ElementType* pointer,
                                         const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
    {
        return layout_;
    }

    /**
     * \brief Get the new sliced tensor.
     *
     * \param idx Tuple of indices: slice(from,to) or scalar.
     * \return Sliced tensor.
     */
    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator[](const Tuple<Ts...>& idx)
    {
        static_assert(IsDynamicBuffer, "Register slice is not supported");
        const auto& shape = layout_.GetShape();
        auto new_shape    = detail::GetSlicedShape(idx, shape);

        const auto& flatten_desc = layout_.GetUnrolledDescriptor();
        auto new_desc            = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);
        const auto new_layout =
            Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
        // Update embed offset
        base_offset_ -= new_layout(make_tuple(Number<0>{}));
        return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
    }

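    // Illustrative slicing call (a sketch; assumes the `slice` helper): for a 2D
    // tensor t, t[make_tuple(slice(0, 2), 4)] returns a 1D tensor viewing rows 0..1
    // of column 4; the dimension indexed by the scalar 4 is frozen and removed.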
    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ auto operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Getter of the tensor's const value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_[offset];
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Calculate and apply the base offset at compile time
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_[Number<index_offset + base_offset>{}];
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Getter of tensor value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_(offset);
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Apply the base offset (calculated at compile time)
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_(Number<index_offset + base_offset>{});
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Get descriptor with all nested dimensions merged.
     *
     * \return Merged nests descriptor.
     */
    __host__ __device__ constexpr auto GetMergedNestingDescriptor()
    {
        return layout_.GetMergedNestingDescriptor();
    }

    /**
     * \brief Get pointer to the data.
     *
     * \return Pointer.
     */
    __host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }

    __host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
    __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }

    /**
     * \brief Get multi index offset to the data.
     *
     * \return Multi index offset.
     */
    __host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }

    /**
     * \brief Apply multi index offset on the tensor.
     *
     * \param multi_idx_offset Multi index offset.
     */
    template <typename MultiIdxOffsets>
    __host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
    {
        multi_idx_offset_ = multi_idx_offset;
        base_offset_ += layout_(multi_idx_offset);
    }
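    // Sketch: SetMultiIdxOffset(make_multi_index(2, 0)) stores the multi-dimensional
    // offset and adds the corresponding linear offset layout_(multi_idx_offset) to
    // base_offset_, so subsequent element accesses are shifted by that amount.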

    private:
    // Disable from doxygen docs generation
    /// @cond INTERNAL
    using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
                                            ElementType,
                                            ElementSpaceSize,
                                            true /*InvalidElementUseNumericalZeroValue*/>;
    using StaticBufferType  = std::conditional_t<
        is_scalar_type<ElementType>::value,
        StaticBuffer<BufferAddressSpace,
                     ElementType,
                     size(Shape{}),
                     true /*InvalidElementUseNumericalZeroValue*/>,
        StaticBufferTupleOfVector<BufferAddressSpace,
                                  TensorElementType,
                                  size(Shape{}) /
                                      scalar_type<std::remove_const_t<ElementType>>::vector_size,
                                  scalar_type<std::remove_const_t<ElementType>>::vector_size,
                                  true /*InvalidElementUseNumericalZeroValue*/>>;
    // Registers (Sgpr/Vgpr) use a static buffer; other address spaces use a dynamic buffer
    using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;

    const Layout<Shape, UnrolledDescriptorType> layout_;
    Buffer buffer_;
    // We use multi_idx_offset_ to enable compile-time creation of a descriptor for
    // partitions or tiles when the tile shape and thread layout are known at compile
    // time (the same descriptor can then be reused by every thread). Additionally,
    // copies between static and dynamic buffers require a descriptor that is known
    // at compile time, so data can be shifted using this multi_idx_offset_.
    MultiIndex<Shape::Size()> multi_idx_offset_;
    // The base offset and the multi-index offset refer to exactly the same element
    // in the tensor (and in physical memory). The multi-index offset is a
    // multi-dimensional index, whereas the base offset is computed through the tensor
    // descriptor (and thus all of its transforms) and is linear (1D). We store
    // base_offset_ to avoid recalculating it.
    index_t base_offset_;
    /// @endcond
};
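// Illustrative usage sketch (not part of this header; `make_layout`, `make_tensor`
// and `slice` are wrapper helpers assumed to be provided by the includes above):
//
//   const auto layout = make_layout(make_tuple(size_m, size_n),
//                                   make_tuple(stride_m, stride_n));
//   auto tensor = make_tensor<MemoryTypeEnum::Global>(ptr, layout);
//   auto value  = tensor(1, 2);           // scalar access through the layout
//   auto column = tensor(slice(0, 2), 4); // sliced view; the scalar dimension is removed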

} // namespace wrapper
} // namespace ck