#ifndef CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_HPP
#define CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_HPP

#include "common.hpp"
#include "ConstantTensorDescriptor.hpp"

namespace ck {

// Compile-time descriptor of a tensor whose dimensions are each formed by
// merging one or more dimensions of an underlying ("original") tensor.
//
// OriginalTensorDesc : ConstantTensorDescriptor<...>
//     it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
//     each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
//
// The number of merged dimensions (nDim) equals the number of merge sequences.
// All members are static/constexpr; the struct carries no runtime state.
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
struct ConstantMergedTensorDescriptor
{
    using Type = ConstantMergedTensorDescriptor;

    // One default-constructed merge sequence per merged dimension, indexable with std::get.
    static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

    // Number of dimensions after merging.
    static constexpr index_t nDim = sizeof...(OriginalDimMergeSeqs);

    // Number of dimensions of the original tensor.
    static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();

    // Only validates that merging does not increase the number of dimensions.
    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
        static_assert(nDim <= nOriginalDim, "wrong!");

        // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
        // OriginalTensorDesc::nDim number of dimensions

        // TODO: check OriginalDimMergeSeqs contains all original dimensions

        // TODO: check there is no duplication in OriginalDimMergeSeqs
    }

    // Returns a default-constructed instance of the original tensor descriptor.
    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
    {
        return OriginalTensorDesc{};
    }

    // Returns the number of merged dimensions.
    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

    // Returns the sequence of original dimensions that make up merged dimension IDim.
    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
    {
        return std::get<IDim>(mOriginalDimMergeSeqs);
    }

    // True when merged dimension IDim is built from more than one original dimension.
    template <index_t IDim>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
    {
        return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
    }

    // Length of merged dimension IDim: the element size of the sub-descriptor that
    // OriginalTensorDesc::Extract yields for the original dimensions merged into IDim.
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
        constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);

        return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
    }

    // Stride of merged dimension IDim. Rejected at compile time for a dimension that
    // merges several original dimensions; otherwise forwards to the stride of the
    // first (front) original dimension in the merge sequence.
    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
                      "wrong! stride of a merged dimension is undefined");

        constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();

        return OriginalTensorDesc::GetStride(Number<idim_original>{});
    }

    // Returns a Sequence holding the length of every merged dimension, in order
    // (same per-dimension computation as GetLength, expanded over all merge sequences).
    __host__ __device__ static constexpr auto GetLengths()
    {
        return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
    }

    // Total element count; forwarded unchanged from the original descriptor.
    __host__ __device__ static constexpr index_t GetElementSize()
    {
        return OriginalTensorDesc::GetElementSize();
    }

    // Inner functor for GetOriginalMultiIndexFromMultiIndex: for position I inside one
    // merge sequence, copies the I-th entry of the partial original multi-index into
    // the slot of the full original multi-index named by OriginalDimsPartial.
    template <class OriginalDimsPartial>
    struct lambda_1_GetOriginalMultiIndexFromMultiIndex
    {
        const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial;
        Array<index_t, nOriginalDim>& original_multi_id;

        __host__ __device__ constexpr lambda_1_GetOriginalMultiIndexFromMultiIndex(
            const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_,
            Array<index_t, nOriginalDim>& original_multi_id_)
            : original_multi_id_partial(original_multi_id_partial_),
              original_multi_id(original_multi_id_)
        {
        }

        template <index_t I>
        __host__ __device__ constexpr void operator()(Number<I>) const
        {
            // original dimension that position I of this merge sequence refers to
            constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});

            index_t itmp = original_multi_id_partial[I];

            original_multi_id.Set(Number<idim_original>{}, itmp);
        }
    };

    // Outer functor for GetOriginalMultiIndexFromMultiIndex: for merged dimension IDim,
    // expands multi_id[IDim] into indices of the contained original dimensions and
    // writes them into original_multi_id via the inner functor.
    struct lambda_0_GetOriginalMultiIndexFromMultiIndex
    {
        const Array<index_t, nDim>& multi_id;
        Array<index_t, nOriginalDim>& original_multi_id;

        __host__ __device__ constexpr lambda_0_GetOriginalMultiIndexFromMultiIndex(
            const Array<index_t, nDim>& multi_id_, Array<index_t, nOriginalDim>& original_multi_id_)
            : multi_id(multi_id_), original_multi_id(original_multi_id_)
        {
        }

        template <index_t IDim>
        __host__ __device__ constexpr void operator()(Number<IDim>) const
        {
            constexpr auto original_dims_partial = std::get<IDim>(Type::mOriginalDimMergeSeqs);

            // get partial original-multi-id corresponding to this merged dimension
            const auto original_multi_id_partial =
                OriginalTensorDesc::Extract(original_dims_partial)
                    .GetMultiIndexFrom1dIndex(multi_id[IDim]);

            static_for<0, original_dims_partial.GetSize(), 1>{}(
                lambda_1_GetOriginalMultiIndexFromMultiIndex<decltype(original_dims_partial)>(
                    original_multi_id_partial, original_multi_id));
        }
    };

    // Converts a multi-index over the merged dimensions into a multi-index over the
    // original dimensions.
    // return type is Array<...>
    __host__ __device__ static constexpr auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        // Filled slot by slot by the functors; a slot is written only for original
        // dimensions that appear in some merge sequence (see TODOs in the constructor).
        Array<index_t, nOriginalDim> original_multi_id;

        static_for<0, nDim, 1>{}(
            lambda_0_GetOriginalMultiIndexFromMultiIndex(multi_id, original_multi_id));

        return original_multi_id;
    }

    // Offset for a merged multi-index given as a compile-time Sequence.
    template <index_t... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
    {
        constexpr auto multi_id = sequence2array(Sequence<Is...>{});

        constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
    }

    // Offset for a merged multi-index given as an Array: translate to the original
    // multi-index, then let the original descriptor compute the offset.
    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);

        return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
    }

    // Convenience overload: packs the individual indices into an Array and forwards.
    template <class... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
    {
        return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
    }

    // Converts a 1d index into a multi-index over the merged dimensions, using a
    // packed descriptor built from the merged lengths.
    __host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
    {
        constexpr auto packed_desc = make_ConstantTensorDescriptor_packed(GetLengths());

        return packed_desc.GetMultiIndexFrom1dIndex(id);
    }
};

// Factory: deduces the descriptor type from an original descriptor value and the
// merge-sequence values; only the argument types are used.
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
                                                                       OriginalDimMergeSeqs...)
{
    return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
}

// Prints the original tensor descriptor behind a merged descriptor, prefixed with s.
// The merge sequences themselves are not printed.
template <class TDesc>
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
{
    print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
}

} // namespace ck
#endif