#pragma once #include "common.hip.hpp" #include "ConstantTensorDescriptor.hip.hpp" // OriginalTensorDesc : ConstantTensorDescriptor<...> // it's the tensor whose dimensions are to be merged // OriginalDimMergeSeqs : Sequence<...>... // each is a sequence of original dimensions (of OriginalTensorDesc) to be merged template struct ConstantMergedTensorDescriptor { static constexpr auto mOriginalDimMergeSeqs = std::tuple{}; static constexpr index_t nDim = std::tuple_size::value; static constexpr index_t nOriginalDim = OriginalDesc::GetNumOfDimension(); __host__ __device__ constexpr ConstantMergedTensorDescriptor() { static_assert(nDim <= nOriginalDim, "wrong!"); // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most // OriginalTensorDesc::nDim number of dimensions // TODO: check there is no duplication in OriginalDimMergeSeqs // TODO: check OriginalDimMergeSeqs contains all original dimensions } __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; } __host__ __device__ static constexpr index_t GetNumOfOriginalDimension() { return nOriginalDim } template __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number) { return (std::Get(mOriginalDimMergeSeqs).GetSize() > 1); } template __host__ __device__ static constexpr index_t GetLength(Number) { constexpr auto original_dims_partial = std::Get(mOriginalDimMergeSeqs); return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize(); } template __host__ __device__ static constexpr index_t GetStride(Number) { static_assert(!ContainMultipleOriginalDimensions(Number{}), "wrong! stride of a merged dimension is undefined"); constexpr auto idim_original = std::Get(mOriginalDimMergeSeqs).Front(); return OriginalTensorDesc::GetStride(Number{}); } __host__ __device__ static constexpr auto GetLengths() { return Sequence{}; } __host__ __device__ static constexpr index_t GetElementSize() { return OriginalTensorDesc::GetElementSize(); } __host__ __device__ static auto GetOriginalMultiIndexFromMultiIndex(Array multi_id) { Array original_multi_id; static_for<0, nDim, 1>{}([&](auto IDim) { constexpr index_t idim = IDim.Get(); constexpr auto original_dims_partial = std::get(mOriginalDimMergeSeqs); // get partial original-multi-id corresponding to this merged dimension constexpr auto original_multi_id_partial = OriginalTensorDesc::Extract(original_dims_partial) .GetMultiIndexFrom1dIndex(multi_id[idim]); // make sure compiler unroll this loop and propagate all the constants for(index_t i = 0; i < original_dims_partial.GetSize(); ++i) { index_t idim_original = original_dims_partial[i]; original_multi_id[idim_original] = original_multi_id_partial[i] } }); return original_multi_id; } __host__ __device__ static index_t GetOffsetFromMultiIndex(Array multi_id) { const auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id); return OriginalTensorDesc::GetOffsetFromMultiIndex(orginal_multi_id); } template __host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is) { return GetOffsetFromMultiIndex(Array{is...}); } __host__ __device__ static Array GetMultiIndexFrom1dIndex(index_t id) { constexpr auto dummy_desc = make_packed_ConstantTensorDescriptor(GetLengths()); return dummy_desc.GetMultiIndexFrom1dIndex(id); } }; template constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc, OriginalDimMergeSeqs...) { return ConstantMergedTensorDescriptor{}; }