#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include <tuple>

// ConstantMergedTensorDescriptor
//   Compile-time descriptor that views OriginalTensorDesc with groups of its
//   dimensions merged into single dimensions.
//
// OriginalTensorDesc : ConstantTensorDescriptor<...>
//     the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
//     each is a sequence of original dimensions (of OriginalTensorDesc) to be
//     merged; the IDim-th sequence becomes merged dimension IDim
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
struct ConstantMergedTensorDescriptor
{
    // Kept for backward compatibility only. Member functions below look the merge
    // sequences up by type (OriginalDimMergeSeq<IDim>) so that this static constexpr
    // object is never odr-used: it has no out-of-class definition, which C++14
    // would require for an odr-use such as binding it to std::get's reference
    // parameter.
    static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

    static constexpr index_t nDim         = sizeof...(OriginalDimMergeSeqs);
    static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();

    // Type of the IDim-th merge sequence, i.e. the original dimensions that are
    // folded into merged dimension IDim.
    template <index_t IDim>
    using OriginalDimMergeSeq =
        typename std::tuple_element<IDim, std::tuple<OriginalDimMergeSeqs...>>::type;

    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
        static_assert(nDim <= nOriginalDim, "wrong!");

        // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
        // OriginalTensorDesc::nDim number of dimensions

        // TODO: check OriginalDimMergeSeqs contains all original dimensions

        // TODO: check there is no duplication in OriginalDimMergeSeqs
    }

    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
    {
        return OriginalTensorDesc{};
    }

    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

    // Returns the Sequence of original dimensions merged into dimension IDim.
    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
    {
        return OriginalDimMergeSeq<IDim>{};
    }

    // True when merged dimension IDim is made of more than one original dimension.
    template <index_t IDim>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
    {
        return OriginalDimMergeSeq<IDim>{}.GetSize() > 1;
    }

    // Length of merged dimension IDim: the element count of the sub-descriptor
    // extracted from the original dimensions it merges.
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
        return OriginalTensorDesc::Extract(OriginalDimMergeSeq<IDim>{}).GetElementSize();
    }

    // Stride of dimension IDim. Only defined when IDim maps to exactly one original
    // dimension (enforced at compile time); returns that dimension's original stride.
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                      "wrong! stride of a merged dimension is undefined");

        constexpr auto idim_original = OriginalDimMergeSeq<IDim>{}.Front();

        return OriginalTensorDesc::GetStride(Number<idim_original>{});
    }

    // Sequence of the merged lengths, one entry per merge sequence.
    __host__ __device__ static constexpr auto GetLengths()
    {
        return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
    }

    // Same as the original descriptor's element size. This presumably equals the
    // product of GetLengths() only when the merge sequences cover every original
    // dimension exactly once, which is not yet checked (see constructor TODOs).
    __host__ __device__ static constexpr index_t GetElementSize()
    {
        return OriginalTensorDesc::GetElementSize();
    }

    // The multi-index translation below is written with named functors rather than
    // lambdas because a constexpr operator() is needed for
    // GetOriginalMultiIndexFromMultiIndex to remain constexpr, and C++14 lambdas
    // cannot be constexpr.

    // Copies the coordinates of the original dimensions listed in OriginalDimsPartial
    // (held in original_multi_id_partial, in sequence order) into their positions in
    // the full original multi-index.
    template <class OriginalDimsPartial>
    struct GetOriginalMultiIndexFromMultiIndex_impl1
    {
        static constexpr index_t nPartialDim = OriginalDimsPartial{}.GetSize();

        const Array<index_t, nPartialDim>& original_multi_id_partial_ref;
        Array<index_t, nOriginalDim>& original_multi_id_ref;

        __host__ __device__ constexpr GetOriginalMultiIndexFromMultiIndex_impl1(
            const Array<index_t, nPartialDim>& original_multi_id_partial,
            Array<index_t, nOriginalDim>& original_multi_id)
            : original_multi_id_partial_ref(original_multi_id_partial),
              original_multi_id_ref(original_multi_id)
        {
        }

        // Handles the I-th entry of OriginalDimsPartial.
        template <index_t I>
        constexpr __host__ __device__ bool operator()(Number<I>) const
        {
            constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});

            index_t itmp = original_multi_id_partial_ref.Get(Number<I>{});

            original_multi_id_ref.Set(Number<idim_original>{}, itmp);

            return true;
        }
    };

    // For merged dimension IDim: decomposes multi_id[IDim] over the sub-descriptor of
    // the original dimensions it merges, then scatters the resulting coordinates into
    // the full original multi-index.
    struct GetOriginalMultiIndexFromMultiIndex_impl0
    {
        const Array<index_t, nDim>& multi_id_ref;
        Array<index_t, nOriginalDim>& original_multi_id_ref;

        __host__ __device__ constexpr GetOriginalMultiIndexFromMultiIndex_impl0(
            const Array<index_t, nDim>& multi_id, Array<index_t, nOriginalDim>& original_multi_id)
            : multi_id_ref(multi_id), original_multi_id_ref(original_multi_id)
        {
        }

        template <index_t IDim>
        constexpr __host__ __device__ bool operator()(Number<IDim>) const
        {
            using OriginalDimsPartial = OriginalDimMergeSeq<IDim>;

            constexpr auto original_dims_partial = OriginalDimsPartial{};

            // get partial original-multi-id corresponding to this merged dimension
            const auto original_multi_id_partial =
                OriginalTensorDesc::Extract(original_dims_partial)
                    .GetMultiIndexFrom1dIndex(multi_id_ref[IDim]);

            static_for<0, original_dims_partial.GetSize(), 1>{}(
                GetOriginalMultiIndexFromMultiIndex_impl1<OriginalDimsPartial>(
                    original_multi_id_partial, original_multi_id_ref));

            return true;
        }
    };

    // Maps a multi-index over the merged dimensions to a multi-index over the
    // original dimensions.
    __host__ __device__ static constexpr auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}(
            GetOriginalMultiIndexFromMultiIndex_impl0(multi_id, original_multi_id));

        return original_multi_id;
    }

    // Compile-time variant: the merged multi-index is given as a Sequence, and the
    // translation to the original multi-index is evaluated as a constant expression.
    template <index_t... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
    {
        static_assert(static_cast<index_t>(sizeof...(Is)) == nDim,
                      "wrong! number of indices does not match number of dimensions");

        constexpr auto multi_id = sequence2array(Sequence<Is...>{});

        constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
    }

    // Offset (as computed by OriginalTensorDesc) of the element at the given merged
    // multi-index.
    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
    }

    // Convenience overload taking one index per merged dimension. Exactly nDim
    // indices are required; fewer would otherwise be silently padded by the Array
    // initialization and produce a wrong offset.
    template <class... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
    {
        static_assert(static_cast<index_t>(sizeof...(Is)) == nDim,
                      "wrong! number of indices does not match number of dimensions");

        return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
    }

    // Decomposes id as a linear index over a packed layout with lengths GetLengths().
    // Note: id is treated as a packed linear index, not as a memory offset of
    // OriginalTensorDesc.
    __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
    {
        constexpr auto dummy_desc = make_ConstantTensorDescriptor_packed(GetLengths());

        return dummy_desc.GetMultiIndexFrom1dIndex(id);
    }
};

// Builds a ConstantMergedTensorDescriptor from value arguments (only their types
// are used).
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
                                                                       OriginalDimMergeSeqs...)
{
    return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
}

// Prints the underlying original tensor descriptor (the merge grouping itself is
// not printed).
template <class TDesc>
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
{
    print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
}