// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "utils/tensor_utils.hpp"
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace ck {
namespace wrapper {
/// @endcond

// Disable from doxygen docs generation
/// @cond INTERNAL
namespace detail {
namespace {
/**
 * \brief Check if Tuple contains Slice object
 *
 * \return True if tuple contains Slice object.
 */
template <typename T>
__host__ __device__ constexpr bool HasSlice(T&&)
{
    return is_detected<is_slice, T>::value;
}
template <typename... Ts>
__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
{
    return (HasSlice(Ts{}) || ...);
}
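// For example (illustrative): a Tuple that contains a Slice element, or a
// nested Tuple holding one, makes HasSlice(...) return true, while a Tuple of
// plain indices returns false. The fold expression above inspects every
// element type of the Tuple.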

/**
 * \brief Calculate new shape after slice from parent shape.
 *
 * \param idxs Tuple of indexes defining slice ranges.
 * \param shape Shape which will be sliced.
 * \return New tensor shape.
 */
template <typename... Ts, typename SlicedShape>
__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs,
                                                  const SlicedShape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto new_shape = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
                {
                    // if tuple does not have any slice then we can remove dimension
                    return Tuple<>{};
                }
                else
                {
                    // if tuple contains a slice, recurse
                    return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
                }
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                // calculate new dimension
                const auto& dim = size(shape.At(num_i));
                const auto val  = idxs.At(num_i).range(dim);
                return make_tuple(val);
            }
            else
            {
                // remove dimension for just value
                return Tuple<>{};
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple<0, 1>(new_shape);
}
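// Illustrative sketch (assuming the wrapper's slice(from, to) helper from
// tensor_utils.hpp): for shape (4, (8, 2)) and indices
// (slice(0, 4), (2, slice(0, 2))) the sliced dimensions keep their slice
// lengths (4 and 2) while the scalar index 2 drops its dimension, so the
// resulting shape is (4, (2)).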

/**
 * \brief Generate Freeze for each of nested shape.
 *
 * \param idx Tuple of start indices for slice.
 * \param shape Shape which will be frozen.
 * \return Generated freeze transforms.
 */
template <typename T, typename Shape>
__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return generate_tuple(
        [&](auto i) {
            // dimension offset from idx
            const auto dim     = unrolled_shape.At(Number<i>{});
            const auto dim_idx = idx % dim;
            idx /= dim;
            return make_freeze_transform(dim_idx);
        },
        Number<decltype(unrolled_shape)::Size()>{});
}
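// The scalar index is decomposed in mixed-radix order (fastest-varying
// dimension first): e.g. for an unrolled shape (3, 4) and idx = 7 the freezes
// are placed at 7 % 3 = 1 in the first dimension and (7 / 3) % 4 = 2 in the
// second, since 7 = 1 + 2 * 3.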

/**
 * \brief Generate transforms for slice tensor.
 *
 * \param idx Tuple of start indices for slice.
 * \param shape Shape which will be sliced.
 * \return Generated transforms.
 */
template <typename... Ts, typename Shape>
__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx,
                                                           const Shape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto transforms = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
            }
            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {

                const auto from  = idx.At(num_i).from_;
                const auto dim   = size<num_i>(shape);
                const auto range = idx.At(num_i).range(dim);
                return make_slice_transform(range, from, from + range);
            }
            else
            {
                // remove dimension for just value
                return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple(transforms);
}
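// Illustrative example: for indices (Slice over [0, 2), 3) and shape (4, 4)
// this yields a slice transform of length 2 on the first dimension and a
// freeze transform at index 3 on the second; UnrollNestedTuple then flattens
// the result so it can be passed to transform_tensor_descriptor.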

template <index_t i, typename LowerIndex>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&)
{
    // There is no output for Freeze transform
    return Sequence<>{};
}

template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
{
    return Sequence<i>{};
}
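// In other words, Freeze transforms contribute no upper dimension (empty
// Sequence), while each Slice transform is tagged with Sequence<i>, the
// position it will occupy among the upper dimensions of the sliced descriptor.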

template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
{
    return Tuple<>{};
}

template <index_t i, typename... Transforms>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& transforms)
{
    constexpr auto num_transforms = Tuple<Transforms...>::Size();
    // Deduce Sequence element for specific transform
    const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
    if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
    {
        const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
    else
    {
        // Increase i if current_elem is Slice transform
        const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
}
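// Example: for a transform tuple (Slice, Freeze, Slice) the generated upper
// dimension ids are (Sequence<0>, Sequence<>, Sequence<1>): the counter i
// advances only when a Slice transform consumes an upper dimension.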

template <typename... Ts, typename Shape, typename FlattenDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
                                                            const Shape& shape,
                                                            const FlattenDescriptor& flatten_desc)
{
    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();

    const auto transforms     = GenerateSliceTransforms(idx, shape);
    using TransformsTupleType = decltype(transforms);

    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
    const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
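// Note: lower_dims assigns one Sequence<i> per unrolled (flattened) input
// dimension, while upper_dims (from GenerateUpperDims) keeps only the sliced
// dimensions, so frozen dimensions disappear from the resulting descriptor.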
} // namespace
} // namespace detail
/// @endcond

/**
 * \brief Tensor wrapper that handles both static (register) and dynamic (memory) buffer logic.
 * The tensor is based on a descriptor stored in the Layout. Additionally, the
 * tensor can be sliced or shifted using a multi-index offset.
 *
 * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
 * \tparam ElementType Element data type.
 * \tparam Shape Tensor shape (layout component).
 * \tparam UnrolledDescriptorType Flattened descriptor (layout component).
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
struct Tensor
{
    public:
    using ElementSpaceSize  = decltype(Layout<Shape, UnrolledDescriptorType>{
        Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
    using TensorElementType = ElementType;                          // DataType

    static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
    static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr ||
                                              BufferAddressSpace == MemoryTypeEnum ::Vgpr);

    __host__ __device__ Tensor() = delete;
    __host__ __device__ constexpr Tensor(ElementType* pointer,
                                         const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
    {
        return layout_;
    }

    /**
     * \brief Get the new sliced tensor.
     *
     * \param idx Tuple of indices: slice(from,to) or scalar.
     * \return Sliced tensor.
     */
    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator[](const Tuple<Ts...>& idx)
    {
        static_assert(IsDynamicBuffer, "Register slice is not supported");
        const auto& shape = layout_.GetShape();
        auto new_shape    = detail::GetSlicedShape(idx, shape);

        const auto& flatten_desc = layout_.GetUnrolledDescriptor();
        auto new_desc            = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);
        const auto new_layout =
            Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
        // Update embed offset
        base_offset_ -= new_layout(make_tuple(Number<0>{}));
        return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
    }

    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ auto operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }
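    // Usage sketch (illustrative; assumes the slice(from, to) helper from
    // tensor_utils.hpp and a tensor over a 2D shape):
    //   auto row  = tensor(m, slice(0, size_n));       // freeze M at m, keep N
    //   auto tile = tensor(slice(0, 2), slice(0, 4));  // 2x4 view
    // Both calls return a new Tensor that references the same underlying
    // buffer through a layout restricted to the requested ranges.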

    /**
     * \brief Getter of the tensor's const value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_[offset];
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Calculate and apply base offset at compile time
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_[Number<index_offset + base_offset>{}];
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ const ElementType& operator()(Idxs... idxs) const
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Getter of tensor value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_(offset);
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Apply embed offset (calculated at compile time)
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_(Number<index_offset + base_offset>{});
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ ElementType& operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }
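    // Access sketch: for dynamic buffers (e.g. Global or LDS memory) the
    // offset is layout_(idx) + base_offset_, evaluated at runtime, e.g.
    //   tensor(m, n) = value;
    // For register buffers (Sgpr/Vgpr) both the index offset and the base
    // offset are computed at compile time, so the StaticBuffer is addressed
    // with a Number<> constant.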

    /**
     * \brief Get descriptor with all nested dimensions merged.
     *
     * \return Descriptor with merged nested dimensions.
     */
    __host__ __device__ constexpr auto GetMergedNestingDescriptor()
    {
        return layout_.GetMergedNestingDescriptor();
    }

    /**
     * \brief Get pointer to the data.
     *
     * \return Pointer.
     */
    __host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }

    __host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
    __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }

    /**
     * \brief Get multi index offset to the data.
     *
     * \return Multi index offset.
     */
    __host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }

    /**
     * \brief Apply multi index offset on the tensor.
     *
     * \param multi_idx_offset Multi index offset.
     */
    template <typename MultiIdxOffsets>
    __host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
    {
        multi_idx_offset_ = multi_idx_offset;
        base_offset_ += layout_(multi_idx_offset);
    }
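    // E.g. shifting a thread's tile by a (block_m, block_n) offset (names
    // illustrative) stores that multi index and folds its linearized value
    // into base_offset_, so later element accesses only add a precomputed
    // scalar on top of layout_(idx).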

    private:
    // Disable from doxygen docs generation
    /// @cond INTERNAL
    using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
                                            ElementType,
                                            ElementSpaceSize,
                                            true /*InvalidElementUseNumericalZeroValue*/>;
    using StaticBufferType  = StaticBuffer<BufferAddressSpace,
                                          ElementType,
                                          size(Shape{}),
                                          true /*InvalidElementUseNumericalZeroValue*/>;
    // For registers use a static buffer, otherwise use a dynamic buffer
    using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;

    const Layout<Shape, UnrolledDescriptorType> layout_;
    Buffer buffer_;
    // We use multi_idx_offset_ to enable the creation of a descriptor at
    // compile time for partitions or tiles when the tile shape and thread
    // layout are known at compile time (the same descriptor can then be used
    // for each thread). Additionally, the copy between static and dynamic
    // buffers requires a descriptor known at compile time, so data can be
    // shifted using such a multi_idx_offset_.
    MultiIndex<Shape::Size()> multi_idx_offset_;
    // The base offset and the multi index offset correspond to exactly the
    // same element in the tensor (and in physical memory). The multi index
    // offset is a multi-dimensional index, while the base offset is computed
    // through the tensor descriptor (thus all its transforms) and is linear
    // (1D). We store base_offset_ to avoid recalculating it.
    index_t base_offset_;
    /// @endcond
};
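// Minimal construction sketch (illustrative; make_layout and slice are assumed
// to come from the wrapper's layout_utils.hpp / tensor_utils.hpp helpers, and
// make_tensor is the same helper used internally above):
//   const auto layout = make_layout(make_tuple(size_m, size_n), strides);
//   auto tensor = make_tensor<MemoryTypeEnum::Global>(ptr, layout);
//   tensor(1, 2) = 42;                          // element access
//   auto row = tensor(1, slice(0, size_n));     // sliced view of row 1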

} // namespace wrapper
} // namespace ck