#pragma once #include "common.hip.hpp" template __host__ __device__ constexpr auto calculate_tensor_strides_default_rank_packed(Lengths) { return reverse_inclusive_scan_sequence(Lengths{}.PopFront(), std::multiplies{}) .PushBack(Number<1>{}); } template __host__ __device__ constexpr auto calculate_tensor_strides_default_rank_aligned(Lengths, Number) { constexpr index_t L_back_align = Align * mod_conv::integer_divide_ceiler{}(Lengths{}.Back(), Align); return calculate_tensor_strides_default_rank_packed( Lengths{}.Modify(Number{}, Number{})); } // MemoryRanks of dimensions is for conversion from offset to multi-index template struct ConstantTensorDescriptor { using Type = ConstantTensorDescriptor; static constexpr index_t nDim = Lengths::GetSize(); __host__ __device__ constexpr ConstantTensorDescriptor() { static_assert(Lengths::GetSize() == Strides::GetSize() && Lengths::GetSize() == MemoryRanks::GetSize(), "nDim not consistent"); #if 0 // require sequence_sort, but it's not implemented yet static_assert(is_same::SortedSeqType, typename arithmetic_sequence_gen<0, nDim, 1>::SeqType>::value, "wrong! invalid MemoryRanks"); #endif } __host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; } template __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number) { return Sequence{}; } __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; } __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; } __host__ __device__ static constexpr auto GetStrides() { return Strides{}; } __host__ __device__ static constexpr auto GetMemoryRanks() { return MemoryRanks{}; } template __host__ __device__ static constexpr index_t GetLength(Number) { return Lengths{}.Get(Number{}); } template __host__ __device__ static constexpr index_t GetStride(Number) { return Strides{}.Get(Number{}); } template __host__ __device__ static constexpr index_t GetMemoryRank(Number) { return MemoryRanks{}.Get(Number{}); } __host__ __device__ static constexpr bool AreStridesNonAscending() { bool flag = true; static_for<0, nDim - 1, 1>{}([&](auto IDim) { constexpr auto IDim_p1 = Number{}; flag = flag && (GetLength(IDim) >= GetLength(IDim_p1)); }); return flag; } template __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T) { return false; } __host__ __device__ static constexpr index_t GetElementSize() { return accumulate_on_sequence(Lengths{}, std::multiplies{}, Number<1>{}); } // WRONG! ReorderGivenOld2New is broken template > __host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{}) { #if 0 constexpr auto lengths_in_rank = GetLengths().ReorderGivenOld2New(MemoryRank{}); constexpr auto strides_in_rank = GetStrides().ReorderGivenOld2new(MemoryRank{}); constexpr index_t element_space_unaligned = accumulate_on_sequence( (lengths_in_rank - Number<1>{}) * strides_in_rank, std::plus{}, Number<1>{}); #else // WRONG! align shouldbe applied to the last memory rank, not the last tensor dimension constexpr index_t element_space_unaligned = accumulate_on_sequence( (GetLengths() - Number<1>{}) * GetStrides(), std::plus{}, Number<1>{}); #endif return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get()); } template __host__ __device__ static index_t GetOffsetFromMultiIndex(Array multi_id) { static_assert(NSize == nDim, "wrong! Dimension not consistent"); index_t offset = 0; static_for<0, nDim, 1>{}([&](auto IDim) { constexpr index_t idim = IDim.Get(); offset += multi_id[idim] * GetStride(IDim); }); return offset; } template __host__ __device__ static index_t GetOffsetFromMultiIndex(Is... is) { return GetOffsetFromMultiIndex(Array{is...}); } template __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence) { static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent"); constexpr auto multi_id = Sequence{}; return accumulate_on_sequence(multi_id * GetStrides(), std::plus{}, Number<0>{}); } #if 0 // ReorderGivenOld2new is broken __host__ __device__ static Array GetMultiIndexFromOffset(index_t offset) { Array ranked_multi_id; constexpr auto ranked_strides = GetStrides().ReorderGivenOld2new(MemoryRanks{}); // check this // calculate index in each of the dimensions in the order of their rank (not dimension) static_for<0, nDim - 1, 1>{}([&](auto IDim) { constexpr index_t idim = IDim.Get(); constexpr index_t stride = ranked_strides.Get(Number{}); ranked_multi_id[idim] = offset / stride; offset -= ranked_multi_id[idim] * stride; }); ranked_multi_id[nDim - 1] = offset / ranked_strides.Get(Number{}); return reorder_array_given_new2old(ranked_multi_id, MemoryRanks{}); // check this } #endif __host__ __device__ static Array GetMultiIndexFrom1dIndex(index_t id) { Array multi_id; constexpr auto dummy_strides = calculate_tensor_strides_default_rank_packed(GetLengths()); // calculate index in each of the dimensions in the order of their dimension (not rank) static_for<0, nDim - 1, 1>{}([&](auto IDim) { constexpr index_t idim = IDim.Get(); constexpr index_t stride = dummy_strides.Get(Number{}); multi_id[idim] = id / stride; id -= multi_id[idim] * stride; }); multi_id[nDim - 1] = id / dummy_strides.Get(Number{}); return multi_id; } __host__ __device__ static auto GetOriginalMultiIndexFromMultiIndex(Array multi_id) { return multi_id; } // This function doesn't do carry check on the highest dimension, for performance reason. // It is the user's responsibility to make sure the result "new_mutli_id" is not out-of-bound // on the highest dimension __host__ __device__ static Array UpdateMultiIndexGivenStepSizeOf1dIndex(Array old_multi_id, index_t step_size_of_1d_index) { auto new_multi_id = old_multi_id + GetMultiIndexFrom1dIndex(step_size_of_1d_index); bool carry = false; // do carry check in reversed order, starting from lowest dimension // don't check the highest dimension static_for<0, nDim - 1, 1>{}([&](auto IDimReverse) { constexpr index_t idim = nDim - 1 - IDimReverse.Get(); constexpr auto IDim = Number{}; if(carry) { ++new_multi_id[idim]; } carry = false; if(new_multi_id[idim] >= GetLength(IDim)) { new_multi_id[idim] -= GetLength(IDim); carry = true; } }); return new_multi_id; } // WRONG! Ranks is broken template __host__ __device__ static constexpr auto Extract(Number... extract_dims) { static_assert(sizeof...(IDims) <= GetNumOfDimension(), "wrong! too many number of dimensions to be extracted"); using extract_lengths = decltype(Lengths{}.Extract(extract_dims...)); using extract_strides = decltype(Strides{}.Extract(extract_dims...)); using extract_ranks = decltype(MemoryRanks{}.Extract(extract_dims...)); #if 0 using new_ranks = typename sequence_sort::Original2SortedType; #else // WRONG! TODO:: implement sequence_sort using new_ranks = typename arithmetic_sequence_gen<0, sizeof...(IDims), 1>::SeqType; #endif return ConstantTensorDescriptor{}; } template __host__ __device__ static constexpr auto Extract(Sequence) { return Extract(Number{}...); } template __host__ __device__ static constexpr auto Inject(ConstantTensorDescriptor) { using leaf_tensor = ConstantTensorDescriptor; // memory rank is broken // TODO: remove memory rank info from tensor descritpor return ConstantTensorDescriptor{}; } template __host__ __device__ static constexpr auto Slice(Number, Number) { using slice_lengths = decltype(Lengths{}.Modify(Number{}, Number{})); return ConstantTensorDescriptor{}; } template struct f_fold_impl { __host__ __device__ constexpr index_t operator()(index_t x) const { return x > Threashold ? x + Delta : x; } }; template __host__ __device__ static constexpr auto Fold(Number, Number...) { constexpr auto fold_intervals = Sequence{}; constexpr index_t fold_intervals_product = accumulate_on_sequence(fold_intervals, std::multiplies{}, Number<1>{}); constexpr auto unfold_length = GetLength(Number{}); constexpr auto unfold_stride = GetStride(Number{}); constexpr auto unfold_rank = GetMemoryRank(Number{}); // length of the dimension to be folded needs to be dividable by fold_interval_product, // otherwise, folding is invalid static_assert(unfold_length % fold_intervals_product == 0, "wrong! length on the dimension to be folded cannot be evenly divided!"); // folded lengths constexpr auto fold_lengths = Sequence{}.Append(fold_intervals); // folded strides constexpr auto fold_strides = Number{} * reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies{}); // folded_ranks constexpr auto fold_ranks = typename arithmetic_sequence_gen::SeqType{}; // increase the ranks that are larger than unfold_rank constexpr auto tmp_ranks = transform_sequences( f_fold_impl{}, GetMemoryRanks()); // left and right constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{}; constexpr auto right = typename arithmetic_sequence_gen::SeqType{}; constexpr auto new_lengths = GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right)); constexpr auto new_strides = GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right)); constexpr auto new_ranks = tmp_ranks.Extract(left).Append(fold_ranks).Append(tmp_ranks.Extract(right)); static_assert(new_ranks.GetSize() == new_lengths.GetSize(), "wrong!"); static_assert(fold_ranks.GetSize() == fold_lengths.GetSize(), "wrong!"); return ConstantTensorDescriptor{}; } template struct f_unfold_impl { __host__ __device__ constexpr index_t operator()(index_t x) const { return x > Threashold ? x - Delta : x; } }; template __host__ __device__ static constexpr auto Unfold(Number, Number) { static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim && FirstUnfoldDim <= LastUnfoldDim, "wrong! should have FirstUnfoldDim <= LastUnfoldDim!"); #if 0 // cannot compile: compiler complain about constexpr // dimensions to be unfold need to be in descending order (w.r.t. strides), and need to be // packed in memory, otherwise, unfolding is invalid static_for{}([&](auto IDim_) { constexpr auto IDim = decltype(IDim_){}; constexpr auto IDim_p1 = IDim + Number<1>{}; // check stride static_assert( GetStride(IDim) >= GetStride(IDim_p1), "wrong! dimensions to be unfolded need to be in descending order w.r.t strides"); // check if packed static_assert(GetStride(IDim_p1) * GetLength(IDim_p1) == GetStride(IDim), "wrong! dimensions to be unfolded need to be packed"); // check ranks static_assert(GetMemoryRank(IDim_p1) == GetMemoryRank(IDim) + 1, "wrong! ranks of dimensions to be unfolded need to be in increasing and " "continuous ranks"); }); #endif // left and right constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::SeqType{}; constexpr auto middle = typename arithmetic_sequence_gen::SeqType{}; constexpr auto right = typename arithmetic_sequence_gen::SeqType{}; // unfolded length, stride and rank constexpr index_t unfold_length = accumulate_on_sequence( GetLengths().Extract(middle), std::multiplies{}, Number<1>{}); constexpr index_t unfold_stride = GetStride(Number{}); constexpr index_t unfold_rank = GetMemoryRank(Number{}); // decrease the ranks that are larger than the rank of LastUnfoldDim constexpr auto tmp_ranks = transform_sequences(f_unfold_impl{}), LastUnfoldDim - FirstUnfoldDim + 1>{}, GetMemoryRanks()); // new lengths, strides and ranks constexpr auto new_lengths = GetLengths() .Extract(left) .PushBack(Number{}) .Append(GetLengths().Extract(right)); constexpr auto new_strides = GetStrides() .Extract(left) .PushBack(Number{}) .Append(GetStrides().Extract(right)); constexpr auto new_ranks = tmp_ranks.Extract(left) .PushBack(Number{}) .Append(tmp_ranks.Extract(right)); return ConstantTensorDescriptor{}; } template __host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old) { return ConstantTensorDescriptor{}; } #if 0 // require sequence_sort, which is not implemented yet template __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) { return ConstantTensorDescriptor{}; } #endif }; template __host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank_packed(Lengths) { using Strides = decltype(calculate_tensor_strides_default_rank_packed(Lengths{})); using MemoryRanks = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::SeqType; return ConstantTensorDescriptor{}; } template __host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank(Lengths, Strides) { using MemoryRanks = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::SeqType; return ConstantTensorDescriptor{}; } template __host__ __device__ constexpr auto make_ConstantTensorDescriptor_default_rank_aligned(Lengths, Number) { using Strides = decltype(calculate_tensor_strides_default_rank_aligned(Lengths{}, Number{})); using MemoryRanks = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::SeqType; return ConstantTensorDescriptor{}; } template __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) { constexpr index_t ndim = TDesc::GetNumOfDimension(); static_assert(ndim >= 2 && ndim <= 10, "wrong!"); static_if{}([&](auto fwd) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto desc = fwd(TDesc{}); printf("%s dim %u, lengths {%u %u}, strides {%u %u}, ranks {%u %u}\n", s, desc.GetNumOfDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetStride(I0), desc.GetStride(I1), desc.GetMemoryRank(I0), desc.GetMemoryRank(I1)); }); static_if{}([&](auto fwd) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto desc = fwd(TDesc{}); printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}, ranks {%u %u %u}\n", s, desc.GetNumOfDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetMemoryRank(I0), desc.GetMemoryRank(I1), desc.GetMemoryRank(I2)); }); static_if{}([&](auto fwd) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto desc = fwd(TDesc{}); printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}, ranks {%u %u %u %u}\n", s, desc.GetNumOfDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetMemoryRank(I0), desc.GetMemoryRank(I1), desc.GetMemoryRank(I2), desc.GetMemoryRank(I3)); }); static_if{}([&](auto fwd) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; constexpr auto desc = fwd(TDesc{}); printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}, ranks {%u %u %u %u " "%u}\n", s, desc.GetNumOfDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4), desc.GetMemoryRank(I0), desc.GetMemoryRank(I1), desc.GetMemoryRank(I2), desc.GetMemoryRank(I3), desc.GetMemoryRank(I4)); }); static_if{}([&](auto fwd) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; constexpr auto I5 = Number<5>{}; constexpr auto desc = fwd(TDesc{}); printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}, ranks {%u %u " "%u %u %u %u}\n", s, desc.GetNumOfDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetLength(I5), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4), desc.GetStride(I5), desc.GetMemoryRank(I0), desc.GetMemoryRank(I1), desc.GetMemoryRank(I2), desc.GetMemoryRank(I3), desc.GetMemoryRank(I4), desc.GetMemoryRank(I5)); }); static_if{}([&](auto fwd) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; constexpr auto desc = fwd(TDesc{}); printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}, ranks " "{%u %u %u %u %u %u %u}\n", s, desc.GetNumOfDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetLength(I5), desc.GetLength(I6), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4), desc.GetStride(I5), desc.GetStride(I6), desc.GetMemoryRank(I0), desc.GetMemoryRank(I1), desc.GetMemoryRank(I2), desc.GetMemoryRank(I3), desc.GetMemoryRank(I4), desc.GetMemoryRank(I5), desc.GetMemoryRank(I6)); }); static_if{}([&](auto fwd) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; constexpr auto I7 = Number<7>{}; constexpr auto desc = fwd(TDesc{}); printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}, " "ranks {%u %u %u %u %u %u %u %u}\n", s, desc.GetNumOfDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetLength(I5), desc.GetLength(I6), desc.GetLength(I7), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4), desc.GetStride(I5), desc.GetStride(I6), desc.GetStride(I7), desc.GetMemoryRank(I0), desc.GetMemoryRank(I1), desc.GetMemoryRank(I2), desc.GetMemoryRank(I3), desc.GetMemoryRank(I4), desc.GetMemoryRank(I5), desc.GetMemoryRank(I6), desc.GetMemoryRank(I7)); }); static_if{}([&](auto fwd) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; constexpr auto I7 = Number<7>{}; constexpr auto I8 = Number<8>{}; constexpr auto desc = fwd(TDesc{}); printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u " "%u}, ranks {%u %u %u %u %u %u %u %u %u}\n", s, desc.GetNumOfDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetLength(I5), desc.GetLength(I6), desc.GetLength(I7), desc.GetLength(I8), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4), desc.GetStride(I5), desc.GetStride(I6), desc.GetStride(I7), desc.GetStride(I8), desc.GetMemoryRank(I0), desc.GetMemoryRank(I1), desc.GetMemoryRank(I2), desc.GetMemoryRank(I3), desc.GetMemoryRank(I4), desc.GetMemoryRank(I5), desc.GetMemoryRank(I6), desc.GetMemoryRank(I7), desc.GetMemoryRank(I8)); }); static_if{}([&](auto fwd) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; constexpr auto I7 = Number<7>{}; constexpr auto I8 = Number<8>{}; constexpr auto I9 = Number<9>{}; constexpr auto desc = fwd(TDesc{}); printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u " "%u %u %u}, ranks {%u %u %u %u %u %u %u %u %u %u}\n", s, desc.GetNumOfDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetLength(I5), desc.GetLength(I6), desc.GetLength(I7), desc.GetLength(I8), desc.GetLength(I9), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4), desc.GetStride(I5), desc.GetStride(I6), desc.GetStride(I7), desc.GetStride(I8), desc.GetStride(I9), desc.GetMemoryRank(I0), desc.GetMemoryRank(I1), desc.GetMemoryRank(I2), desc.GetMemoryRank(I3), desc.GetMemoryRank(I4), desc.GetMemoryRank(I5), desc.GetMemoryRank(I6), desc.GetMemoryRank(I7), desc.GetMemoryRank(I8), desc.GetMemoryRank(I9)); }); }