// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "utils/tensor_utils.hpp"
#include "utils/tensor_partition.hpp"
#include "utils/layout_utils.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Tensor wrapper that performs static and dynamic buffer logic.
 *
 * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
 * \tparam ElementType Element data type.
 * \tparam Shape Tensor shape (layout component).
 * \tparam UnnestedDescriptorType Unnested descriptor (layout component).
 * \tparam NumVectors Number of vectors (only for VGPR, SGPR).
 * \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR).
 */
template <MemoryTypeEnum BufferAddressSpace,
          typename ElementType,
          typename Shape,
          typename UnnestedDescriptorType,
          index_t NumVectors,
          index_t ScalarPerVector>
struct Tensor
{
    private:
    // Check if Tuple contains Slice object
    template <typename T>
    __host__ __device__ constexpr static bool IsSlicing(T&&)
    {
        return is_detected<is_slice, T>::value;
    }

    template <typename... Ts>
    __host__ __device__ constexpr static bool IsSlicing(Tuple<Ts...>&&)
    {
        return (IsSlicing(Ts{}) || ...);
    }

    // Calculate new tensor shape after slice
    template <typename... Ts, typename ShapeTmpType>
    __host__ __device__ constexpr auto GetShapeFromSlicedTensor(const Tuple<Ts...>& idx,
                                                                const ShapeTmpType& shape) const
    {
        // Pack each value in tuple to remove empty tuples after generation
        auto new_shape = generate_tuple(
            [&](auto i) {
                constexpr auto num_i = Number<i>{};
                if constexpr(is_detected<is_tuple,
                                         tuple_element_t<i.value, Tuple<Ts...>>>::value)
                {
                    if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{}))
                    {
                        // if tuple does not have any slice then we can remove dimension
                        return Tuple<>{};
                    }
                    else
                    {
                        // if tuple then recurrence
                        return make_tuple(
                            GetShapeFromSlicedTensor(idx.At(num_i), shape.At(num_i)));
                    }
                }
                else if constexpr(is_detected<is_slice,
                                              tuple_element_t<i.value, Tuple<Ts...>>>::value)
                {
                    // calculate new dimension
                    const auto& dim = size(shape.At(num_i));
                    const auto val  = idx.At(num_i).range(dim);
                    return make_tuple(val);
                }
                else
                {
                    // remove dimension for just value
                    return Tuple<>{};
                }
            },
            Number<Tuple<Ts...>::Size()>{});
        // Remove empty tuples (deleted elements) and return
        return UnrollNestedTuple<0, 1>(new_shape);
    }

    // Generate Freeze transform for each dimension of a nested shape
    template <typename T, typename ShapeTmpType>
    __host__ __device__ constexpr auto GenerateMultipleFreeze(T idx,
                                                              const ShapeTmpType& shape) const
    {
        const auto unrolled_shape = UnrollNestedTuple(shape);
        return generate_tuple(
            [&](auto i) {
                // dimension offset from idx
                const auto dim     = unrolled_shape.At(Number<i>{});
                const auto dim_idx = idx % dim;
                idx /= dim;
                return make_freeze_transform(dim_idx);
            },
            Number<decltype(unrolled_shape)::Size()>{});
    }

    // Calculate descriptor transforms (Slice/Freeze) for a sliced tensor
    template <typename... Ts, typename ShapeTmpType>
    __host__ __device__ constexpr auto
    GetTransformsFromSlicedTensor(const Tuple<Ts...>& idx, const ShapeTmpType& shape) const
    {
        // Pack each value in tuple to remove empty tuples after generation
        auto transforms = generate_tuple(
            [&](auto i) {
                constexpr auto num_i = Number<i>{};
                if constexpr(is_detected<is_tuple,
                                         tuple_element_t<i.value, Tuple<Ts...>>>::value)
                {
                    // if tuple then recurrence
                    return GetTransformsFromSlicedTensor(idx.At(num_i), shape.At(num_i));
                }
                else if constexpr(is_detected<is_slice,
                                              tuple_element_t<i.value, Tuple<Ts...>>>::value)
                {
                    // slice dimension
                    const auto from  = idx.At(num_i).from_;
                    const auto dim   = shape.At(num_i);
                    const auto range = idx.At(num_i).range(dim);
                    return make_slice_transform(range, from, from + range);
                }
                else
                {
                    // remove dimension for just value
                    return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
                }
            },
            Number<Tuple<Ts...>::Size()>{});
        // Remove empty tuples (deleted elements) and return
        return UnrollNestedTuple(transforms);
    }

    // There is no output for Freeze transform
    template <index_t i, typename LowerIndex>
    __host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&) const
    {
        return Sequence<>{};
    }

    // Slice transform contributes one upper dimension
    template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
    __host__ __device__ constexpr auto
    GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&) const
    {
        return Sequence<i>{};
    }

    template <index_t i>
    __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&) const
    {
        return Tuple<>{};
    }
    template <index_t i, typename... Ts>
    __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Ts...>& transforms) const
    {
        constexpr auto num_transforms = Tuple<Ts...>::Size();
        // Deduce Sequence element for specific transform
        const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
        if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
        {
            // Freeze transform has no upper dimension, keep i unchanged
            const auto next_tuple =
                GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
            return concat_tuple(make_tuple(current_elem), next_tuple);
        }
        else
        {
            // Increase i if current_elem is Slice transform
            const auto next_tuple =
                GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
            return concat_tuple(make_tuple(current_elem), next_tuple);
        }
    }

    // Create descriptor for a sliced tensor from the flattened descriptor
    template <typename... Ts, typename ShapeTmpType, typename FlattenDescriptor>
    __host__ __device__ constexpr auto
    GetDescriptorFromSlicedTensor(const Tuple<Ts...>& idx,
                                  const ShapeTmpType& shape,
                                  const FlattenDescriptor& flatten_desc) const
    {
        constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();

        const auto transforms     = GetTransformsFromSlicedTensor(idx, shape);
        using TransformsTupleType = decltype(transforms);

        const auto lower_dims =
            generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
        const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};

        return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
    }

    public:
    using ElementSpaceSize = decltype(Layout<Shape, UnnestedDescriptorType>{
        Shape{}, UnnestedDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
    using TensorElementType = ElementType;                         // DataType

    static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
    static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum::Sgpr ||
                                              BufferAddressSpace == MemoryTypeEnum::Vgpr);

    __host__ __device__ Tensor() = delete;

    __host__ __device__ Tensor(ElementType* pointer,
                               const Layout<Shape, UnnestedDescriptorType>& layout)
        : layout_(layout),
          buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize()))
    {
    }

    __host__ __device__ Tensor(const Layout<Shape, UnnestedDescriptorType>& layout)
        : layout_(layout)
    {
        static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
    }

    __host__ __device__ constexpr const Layout<Shape, UnnestedDescriptorType>& GetLayout() const
    {
        return layout_;
    }

    // Getter for new sliced tensor
    template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator[](const Tuple<Ts...>& idx) const
    {
        static_assert(IsDynamicBuffer, "Register slice is not supported");
        const auto& shape = layout_.GetShape();
        auto new_shape    = GetShapeFromSlicedTensor(idx, shape);

        const auto& flatten_desc = layout_.GetUnnestedDescriptor();
        auto new_desc            = GetDescriptorFromSlicedTensor(idx, shape, flatten_desc);
        const auto new_layout =
            Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
        return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
    }

    template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ auto operator()(const Tuple<Ts...>& idx) const
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<IsSlicing(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ auto operator()(Idxs... idxs) const
    {
        return this->operator[](make_tuple(idxs...));
    }

    // Getter for the const value
    template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx);
            return buffer_[offset];
        }
        else
        {
            constexpr index_t offset =
                Layout<Shape, UnnestedDescriptorType>{Shape{}, UnnestedDescriptorType{}}
                    .template operator()<Tuple<Ts...>>();
            return buffer_[Number<offset>{}];
        }
    }

    template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ const ElementType& operator()(Idxs... idxs) const
    {
        return this->operator[](make_tuple(idxs...));
    }
    // Getter for the value reference
    template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
    {
        if constexpr(IsDynamicBuffer)
        {
            const index_t offset = layout_(idx);
            return buffer_(offset);
        }
        else
        {
            constexpr index_t offset =
                Layout<Shape, UnnestedDescriptorType>{Shape{}, UnnestedDescriptorType{}}
                    .template operator()<Tuple<Ts...>>();
            return buffer_(Number<offset>{});
        }
    }

    template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
    __host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
    {
        return this->operator[](idx);
    }

    template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
    __host__ __device__ ElementType& operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }

    __host__ __device__ constexpr auto GetDefaultDescriptor()
    {
        return layout_.GetDefaultDescriptor();
    }

    __host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }

    private:
    using DynamicBufferType = DynamicBuffer<BufferAddressSpace,
                                            ElementType,
                                            ElementSpaceSize,
                                            true /*InvalidElementUseNumericalZeroValue*/>;
    using StaticBufferType =
        StaticBufferTupleOfVector<BufferAddressSpace,
                                  ElementType,
                                  NumVectors,
                                  ScalarPerVector,
                                  true /*InvalidElementUseNumericalZeroValue*/>;
    // If register use static buffer, else use dynamic buffer
    using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;

    const Layout<Shape, UnnestedDescriptorType> layout_;
    Buffer buffer_;
};

} // namespace wrapper
} // namespace ck
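// ----------------------------------------------------------------------------
// Usage sketch (illustration only, not part of the original header). It relies
// on the `make_layout`, `make_tensor` and `slice` helpers declared in the
// wrapper utility headers included above; exact helper signatures may differ
// between releases, so treat the snippet as a hedged example rather than a
// reference:
//
//   float* p_global = /* device pointer to 4 x 8 floats */;
//
//   // Row-major 4x8 layout: shape (4, 8), strides (8, 1).
//   const auto layout = ck::wrapper::make_layout(ck::make_tuple(4, 8),
//                                                ck::make_tuple(8, 1));
//   auto tensor =
//       ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(p_global, layout);
//
//   // Plain integer indices return an element reference (the Layout maps the
//   // index to a buffer offset); a Slice in the index returns a new sliced
//   // tensor view, and integer indices mixed with slices freeze (remove) the
//   // corresponding dimensions.
//   tensor(1, 2) = 42.f;                            // element access
//   auto row = tensor(1, ck::wrapper::slice(0, 8)); // 1-D view of row 1
// ----------------------------------------------------------------------------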