// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "utils/tensor_utils.hpp"
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"

namespace ck {
namespace wrapper {

namespace {
namespace detail {
/**
 * \brief Check if Tuple contains Slice object.
 *
 * \return True if tuple contains Slice object.
 */
template <typename T>
__host__ __device__ constexpr bool HasSlice(T&&)
{
    return is_detected<is_slice, T>::value;
}
template <typename... Ts>
__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
{
    return (HasSlice(Ts{}) || ...);
}

/**
 * \brief Calculate new shape after slice from parent shape.
 *
 * \param idxs Tuple of indexes defining slice ranges.
 * \param shape Shape which will be sliced.
 * \return New tensor shape.
 */
template <typename... Ts, typename SlicedShape>
__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs,
                                                  const SlicedShape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto new_shape = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
                {
                    // if tuple does not have any slice then we can remove dimension
                    return Tuple<>{};
                }
                else
                {
                    // if tuple then recurrence
                    return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
                }
            }
            else if constexpr(is_detected<is_slice,
                                          tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                // calculate new dimension
                const auto& dim = size(shape.At(num_i));
                const auto val  = idxs.At(num_i).range(dim);
                return make_tuple(val);
            }
            else
            {
                // remove dimension for just value
                return Tuple<>{};
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple<0, 1>(new_shape);
}

/**
 * \brief Generate a Freeze transform for each dimension of a nested shape.
 *
 * \param idx Flat start index of the slice (decomposed per dimension).
 * \param shape Shape which will be frozen.
 * \return Generated freeze transforms.
 */
template <typename T, typename Shape>
__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
{
    const auto unrolled_shape = UnrollNestedTuple(shape);
    return generate_tuple(
        [&](auto i) {
            // dimension offset from idx
            const auto dim     = unrolled_shape.At(Number<i>{});
            const auto dim_idx = idx % dim;
            idx /= dim;
            return make_freeze_transform(dim_idx);
        },
        Number<decltype(unrolled_shape)::Size()>{});
}
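// Illustrative sketch (the shape and index values below are assumptions chosen
// only for this example): for a nested shape (2, (3, 4)) the unrolled shape is
// (2, 3, 4) and GenerateMultipleFreeze(7, shape) walks it with modulo/divide:
//   dim = 2 -> freeze(7 % 2 = 1), remaining idx = 3
//   dim = 3 -> freeze(3 % 3 = 0), remaining idx = 1
//   dim = 4 -> freeze(1 % 4 = 1), remaining idx = 0
// so the result behaves like
//   make_tuple(make_freeze_transform(1),
//              make_freeze_transform(0),
//              make_freeze_transform(1));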
/**
 * \brief Generate transforms for a sliced tensor.
 *
 * \param idx Tuple of start indices for slice.
 * \param shape Shape which will be sliced.
 * \return Generated transforms.
 */
template <typename... Ts, typename Shape>
__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx,
                                                           const Shape& shape)
{
    // Pack each value in tuple to remove empty tuples after generation
    auto transforms = generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
            }
            else if constexpr(is_detected<is_slice,
                                          tuple_element_t<i.value, Tuple<Ts...>>>::value)
            {
                const auto from  = idx.At(num_i).from_;
                const auto dim   = size(shape.At(num_i));
                const auto range = idx.At(num_i).range(dim);
                return make_slice_transform(range, from, from + range);
            }
            else
            {
                // remove dimension for just value
                return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
            }
        },
        Number<Tuple<Ts...>::Size()>{});
    // Remove empty tuples (deleted elements) and return
    return UnrollNestedTuple(transforms);
}

template <index_t i, typename... Ts>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<Ts...>&)
{
    // There is no output for Freeze transform
    return Sequence<>{};
}

template <index_t i, typename... Ts>
__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<Ts...>&)
{
    return Sequence<i>{};
}

template <index_t i>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
{
    return Tuple<>{};
}

template <index_t i, typename... Ts>
__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Ts...>& transforms)
{
    constexpr auto num_transforms = Tuple<Ts...>::Size();
    // Deduce Sequence element for specific transform
    const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
    if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
    {
        const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
    else
    {
        // Increase i if current_elem is Slice transform
        const auto next_tuple =
            GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
        return concat_tuple(make_tuple(current_elem), next_tuple);
    }
}

template <typename... Ts, typename Shape, typename FlattenDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
                                                            const Shape& shape,
                                                            const FlattenDescriptor& flatten_desc)
{
    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();

    const auto transforms     = GenerateSliceTransforms(idx, shape);
    using TransformsTupleType = decltype(transforms);

    const auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
    const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};

    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
}
} // namespace detail
} // namespace
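// Illustrative walk-through of the detail helpers above (a sketch; the concrete
// shape, indices and transform spellings are assumptions made for the example,
// with slice(from, to) taken to cover the half-open range [from, to)).
// For shape (4, 8) and idx = make_tuple(2, slice(1, 7)):
//   * GetSlicedShape(idx, shape)          -> (6)            - dimension 0 is removed,
//                                                             dimension 1 keeps range 7 - 1 = 6
//   * GenerateSliceTransforms(idx, shape) -> (freeze(2), slice(6, 1, 7))
//   * lower_dims                          -> (Sequence<0>, Sequence<1>)
//   * GenerateUpperDims<0>(transforms)    -> (Sequence<>, Sequence<0>)
//     (Freeze produces no output dimension; the Slice output becomes upper dim 0)
// transform_tensor_descriptor then yields a 1D descriptor of length 6 whose
// element i maps to element (2, 1 + i) of the original flattened descriptor.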
/**
 * \brief Tensor wrapper that performs static and dynamic buffer logic.
 * The tensor is based on a descriptor stored in the Layout. Additionally,
 * the tensor can be sliced or shifted using a multi-index offset.
 *
 * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
 * \tparam ElementType Element data type.
 * \tparam Shape Tensor shape (layout component).
 * \tparam UnrolledDescriptorType Flattened descriptor (layout component).
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnrolledDescriptorType>
struct Tensor
{
    public:
    using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
        Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
    using TensorElementType =
        std::conditional_t<is_scalar_type<ElementType>::value,
                           ElementType,
                           typename scalar_type<ElementType>::type>; // DataType

    static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
    static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum::Sgpr ||
                                              BufferAddressSpace == MemoryTypeEnum::Vgpr);

    __host__ __device__ Tensor() = delete;
    __host__ __device__ constexpr Tensor(ElementType* pointer,
                                         const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
        : layout_(layout),
          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
          base_offset_(0)
    {
        static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
    {
        return layout_;
    }

    /**
     * \brief Get a new, sliced tensor.
     *
     * \param idx Tuple of indices: slice(from, to) or scalar.
     * \return Sliced tensor.
     */
    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator[](const Tuple<Ts...>& idx)
    {
        static_assert(IsDynamicBuffer, "Register slice is not supported");
        const auto& shape = layout_.GetShape();
        auto new_shape    = detail::GetSlicedShape(idx, shape);

        const auto& flatten_desc = layout_.GetUnrolledDescriptor();
        auto new_desc             = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);

        const auto new_layout =
            Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
        // Update embed offset
        base_offset_ -= new_layout(make_tuple(Number<0>{}));
        return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
    }

    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ auto operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Getter of the tensor's const value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator[](const Tuple<Ts...>& idx) const
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_[offset];
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{}, UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Calculate and apply base offset in compile-time
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_[Number<index_offset + base_offset>{}];
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator()(const Tuple<Ts...>& idx) const
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ const TensorElementType& operator()(Idxs... idxs) const
    {
        return this->operator[](make_tuple(idxs...));
    }
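    // A minimal usage sketch of the read-only accessors above. The make_layout
    // and make_tensor helpers are assumed to come from the wrapper utility
    // headers included at the top of this file; the shape, strides and pointer
    // values are placeholders for the example.
    //
    //   const auto layout = make_layout(make_tuple(Number<4>{}, Number<8>{}),   // shape
    //                                   make_tuple(Number<8>{}, Number<1>{}));  // strides
    //   auto tensor       = make_tensor<MemoryTypeEnum::Global>(p_global, layout);
    //   const auto val_a  = tensor(1, 2);             // layout_(make_tuple(1, 2)) + base_offset_
    //   const auto val_b  = tensor[make_tuple(1, 2)]; // same element
    //
    // For Vgpr/Sgpr tensors the same calls resolve the offset entirely at
    // compile time (the static buffer is indexed with a Number<>), which is why
    // those branches re-create the Layout from Shape{} and
    // UnrolledDescriptorType{} instead of reading the runtime layout_ member.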
    /**
     * \brief Getter of the tensor value reference.
     *
     * \param idx Tuple of indices.
     * \return Requested value.
     */
    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator[](const Tuple<Ts...>& idx)
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx) + base_offset_;
            return buffer_(offset);
        }
        else
        {
            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{}, UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
            // Apply embed offset (calculated in compile-time)
            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
                Shape{},
                UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
            return buffer_(Number<index_offset + base_offset>{});
        }
    }

    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ TensorElementType& operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }

    /**
     * \brief Get descriptor with all nested dimensions merged.
     *
     * \return Merged nests descriptor.
     */
    __host__ __device__ constexpr auto GetMergedNestingDescriptor()
    {
        return layout_.GetMergedNestingDescriptor();
    }

    /**
     * \brief Get pointer to the data.
     *
     * \return Pointer.
     */
    __host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; }

    __host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
    __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }

    /**
     * \brief Get multi index offset to the data.
     *
     * \return Multi index offset.
     */
    __host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }

    /**
     * \brief Apply multi index offset on the tensor.
     *
     * \param multi_idx_offset Multi index offset.
     */
    template <typename MultiIdxOffsets>
    __host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
    {
        multi_idx_offset_ = multi_idx_offset;
        base_offset_ += layout_(multi_idx_offset);
    }

    private:
    using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
                                            ElementType,
                                            ElementSpaceSize,
                                            true /*InvalidElementUseNumericalZeroValue*/>;
    using StaticBufferType = std::conditional_t<
        is_scalar_type<ElementType>::value,
        StaticBuffer<BufferAddressSpace,
                     ElementType,
                     size(Shape{}),
                     true /*InvalidElementUseNumericalZeroValue*/>,
        StaticBufferTupleOfVector<BufferAddressSpace,
                                  TensorElementType,
                                  size(Shape{}) / scalar_type<ElementType>::vector_size,
                                  scalar_type<ElementType>::vector_size,
                                  true /*InvalidElementUseNumericalZeroValue*/>>;
    // If register, use static buffer; otherwise use dynamic buffer
    using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;

    const Layout<Shape, UnrolledDescriptorType> layout_;
    Buffer buffer_;
    // We use multi_idx_offset_ to enable the creation of a descriptor at
    // compile time for partitions or tiles if the tile shape and thread layout
    // are known at compile time (we can use the same descriptor for each
    // thread). Additionally, the copy between a static and a dynamic buffer
    // requires a descriptor known at compile time, so we can shift data using
    // such a multi_idx_offset_.
    MultiIndex<Shape::Size()> multi_idx_offset_;
    // The base offset and the multi index offset correspond to exactly the
    // same element in the tensor (and in physical memory). The multi index
    // offset is a multi-dimensional index, while the base offset is calculated
    // through the tensor descriptor (thus all its transforms) and is linear
    // (1D). We store base_offset_ to avoid recalculating it multiple times.
    index_t base_offset_;
};

} // namespace wrapper
} // namespace ck
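// End-to-end sketch of slicing and offsetting (illustrative only; the slice and
// make_multi_index helpers are assumed to come from the utility headers
// included above, and `tensor` is a 4 x 8 global-memory tensor built as in the
// accessor sketch earlier in this file):
//
//   // Rows [1, 3), all columns: a 2 x 8 tensor aliasing the same memory
//   auto sub = tensor(slice(1, 3), slice(0, 8));
//   sub(0, 0) = 42; // expected to touch element (1, 0) of the parent tensor
//
//   // Shift the whole tensor by a multi index offset; base_offset_ is updated
//   // through the layout so later element accesses remain a single linear offset.
//   tensor.SetMultiIdxOffset(make_multi_index(1, 0));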