#pragma once
#include "common.hpp"

// Compile-time tensor-descriptor utilities.
//
// A tensor is described entirely in the type system: `Lengths` and `Strides` are
// compile-time integer sequences (`Sequence<...>` from common.hpp), so every query
// below is a constexpr computation and costs nothing at runtime.
//
// NOTE(review): the original text of this file had its template parameter lists
// stripped by a formatting accident; they have been reconstructed here from the
// way each name is used. Confirm against the project's version history.

// Packed (row-major, gap-free) strides for the given lengths:
// stride[i] = product of lengths[i+1..nDim-1], stride[nDim-1] = 1.
template <class Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
{
    return reverse_inclusive_scan_sequence(
               Lengths{}.PopFront(), mod_conv::multiplies<index_t>{}, Number<1>{})
        .PushBack(Number<1>{});
}

// Packed strides, except the last length is first rounded up to a multiple of
// Align, so every higher-dimension stride becomes Align-aligned.
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
{
    constexpr index_t L_back_align =
        Align * mod_conv::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);

    return calculate_tensor_strides_packed(
        Lengths{}.Modify(Number<Lengths::GetSize() - 1>{}, Number<L_back_align>{}));
}

// Compile-time descriptor of an nDim tensor: all lengths/strides are carried in
// the two Sequence template parameters. All members are static/constexpr.
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
    using Type = ConstantTensorDescriptor;

    static constexpr index_t nDim = Lengths::GetSize();

    __host__ __device__ constexpr ConstantTensorDescriptor()
    {
        static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
    }

    // A plain descriptor is its own "original" descriptor (no merged dimensions).
    __host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; }

    // Each dimension of a plain descriptor maps to exactly one original dimension.
    template <index_t IDim>
    __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
    {
        return Sequence<IDim>{};
    }

    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

    __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }

    __host__ __device__ static constexpr auto GetStrides() { return Strides{}; }

    template <index_t I>
    __host__ __device__ static constexpr index_t GetLength(Number<I>)
    {
        return Lengths::Get(Number<I>{});
    }

    template <index_t I>
    __host__ __device__ static constexpr index_t GetStride(Number<I>)
    {
        return Strides::Get(Number<I>{});
    }

    // emulate constexpr lambda: folds "stride[i] == stride[i+1] * length[i+1]"
    // across adjacent dimension pairs into is_continuous
    struct lambda_AreDimensionsContinuous
    {
        bool& is_continuous;

        __host__ __device__ constexpr lambda_AreDimensionsContinuous(bool& is_continuous_)
            : is_continuous(is_continuous_)
        {
        }

        template <index_t IDim_>
        __host__ __device__ constexpr void operator()(Number<IDim_>) const
        {
            constexpr auto IDim    = Number<IDim_>{};
            constexpr auto IDim_p1 = Number<IDim_ + 1>{};

            is_continuous =
                is_continuous && (GetStride(IDim) >= GetStride(IDim_p1) &&
                                  GetStride(IDim) == GetStride(IDim_p1) * GetLength(IDim_p1));
        }
    };

    // True when every adjacent dimension pair is memory-contiguous
    // (stride[i] == stride[i+1] * length[i+1] for all i).
    __host__ __device__ static constexpr bool AreDimensionsContinuous()
    {
        bool is_continuous = true;

        static_for<0, nDim - 1, 1>{}(lambda_AreDimensionsContinuous(is_continuous));

        return is_continuous;
    }

    // Packed = contiguous dimensions plus unit stride on the innermost dimension.
    __host__ __device__ static constexpr bool IsPackedTensor()
    {
        return AreDimensionsContinuous() && GetStride(Number<nDim - 1>{}) == 1;
    }

    // A plain descriptor never merges multiple original dimensions into one.
    template <class T>
    __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
    {
        return false;
    }

    // Number of elements (product of lengths), ignoring any stride gaps.
    __host__ __device__ static constexpr index_t GetElementSize()
    {
        return accumulate_on_sequence(Lengths{}, mod_conv::multiplies<index_t>{}, Number<1>{});
    }

    // Memory footprint in elements: offset of the last element + 1, rounded up
    // to a multiple of align.
    template <class Align = Number<1>>
    __host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{})
    {
        // This is WRONG! align should be applied to the last memory rank, not the last tensor
        // dimension
        constexpr index_t element_space_unaligned = accumulate_on_sequence(
            (GetLengths() - Number<1>{}) * GetStrides(), mod_conv::plus<index_t>{}, Number<1>{});

        return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
    }

    // emulate constexpr lambda: accumulates multi_id[i] * stride[i] into offset
    template <index_t NSize>
    struct lambda_GetOffsetFromMultiIndex
    {
        Array<index_t, NSize>& multi_id;
        index_t& offset;

        __host__ __device__ constexpr lambda_GetOffsetFromMultiIndex(
            Array<index_t, NSize>& multi_id_, index_t& offset_)
            : multi_id(multi_id_), offset(offset_)
        {
        }

        template <class X>
        __host__ __device__ constexpr void operator()(X IDim) const
        {
            offset += multi_id[IDim] * Type::GetStride(IDim);
        }
    };

    // Linear offset = dot(multi_id, strides).
    template <index_t NSize>
    __host__ __device__ static constexpr index_t
    GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
    {
        static_assert(NSize == nDim, "wrong! Dimension not consistent");

        index_t offset = 0;

        static_for<0, nDim, 1>{}(lambda_GetOffsetFromMultiIndex<NSize>(multi_id, offset));

        return offset;
    }

    // Convenience overload: one index per dimension as separate arguments.
    template <class... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
    {
        return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
    }

    // Fully compile-time overload: index comes in as a Sequence.
    template <index_t... Is>
    __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
    {
        static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");

        constexpr auto multi_id = Sequence<Is...>{};

        return accumulate_on_sequence(
            multi_id * GetStrides(), mod_conv::plus<index_t>{}, Number<0>{});
    }

    // emulate constexpr lambda: peels one digit of the mixed-radix decomposition
    // of id using the packed strides
    template <class PackedStrides>
    struct lambda_GetMultiIndexFrom1dIndex
    {
        index_t& id;
        Array<index_t, nDim>& multi_id;

        __host__ __device__ constexpr lambda_GetMultiIndexFrom1dIndex(
            index_t& id_, Array<index_t, nDim>& multi_id_)
            : id(id_), multi_id(multi_id_)
        {
        }

        template <class IDim_>
        __host__ __device__ constexpr void operator()(IDim_) const
        {
            constexpr auto IDim = IDim_{};

            constexpr index_t stride = PackedStrides::Get(IDim);

            multi_id.Set(IDim, id / stride);

            id -= multi_id[IDim] * stride;
        }
    };

    // Inverse of the packed linearization: decompose a 1d index into per-dimension
    // indices using packed strides (NOT this descriptor's possibly-padded strides).
    __host__ __device__ static constexpr Array<index_t, nDim>
    GetMultiIndexFrom1dIndex(index_t id)
    {
        Array<index_t, nDim> multi_id;

        using PackedStrides = decltype(calculate_tensor_strides_packed(GetLengths()));

        // calculate index in each of the dimensions in the order of their dimension
        static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));

        multi_id.Set(Number<nDim - 1>{}, id / PackedStrides::Get(Number<nDim - 1>{}));

        return multi_id;
    }

    // Identity for a plain descriptor (no dimension merging to undo).
    __host__ __device__ static constexpr auto
    GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
    {
        return multi_id;
    }

    // This function doesn't do carry check on the highest dimension for positive stepping (or
    // borrow check on the lowest dimension for negative stepping), for performance reason. It is
    // the user's responsibility to make sure the result "new_multi_id" is not out-of-bound on the
    // highest dimension for positive stepping (or on the lowest dimension for negative stepping)
    template <bool PositiveDirection>
    __host__ __device__ static Array<index_t, nDim>
    UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
                                           index_t step_size_of_1d_index,
                                           integral_constant<bool, PositiveDirection>)
    {
        Array<index_t, nDim> new_multi_id;

        const auto step_sizes = GetMultiIndexFrom1dIndex(step_size_of_1d_index);

        static_if<PositiveDirection>{}([&](auto) {
            new_multi_id = old_multi_id + step_sizes;

            bool carry = false;

            // do carry check in reversed order, starting from lowest dimension
            // don't check the highest dimension
            static_for<0, nDim, 1>{}([&](auto IDimReverse) {
                constexpr index_t idim = nDim - 1 - IDimReverse.Get();
                constexpr auto IDim    = Number<idim>{};

                if(carry)
                {
                    ++new_multi_id(idim);
                }

                carry = false;

                if(new_multi_id[idim] >= GetLength(IDim))
                {
                    new_multi_id(idim) -= GetLength(IDim);
                    carry = true;
                }
            });
        }).Else([&](auto) {
            // shift up multi-id to avoid unsigned integer underflow during intermediate
            // calculations. After the shift, should have new_multi_id[...] >= 1
            new_multi_id = old_multi_id + (GetLengths() - step_sizes);

            bool borrow = false;

            // do borrow check in reversed order, starting from lowest dimension
            // don't check the highest dimension
            static_for<0, nDim, 1>{}([&](auto IDimReverse) {
                constexpr index_t idim = nDim - 1 - IDimReverse.Get();
                constexpr auto IDim    = Number<idim>{};

                if(borrow)
                {
                    --new_multi_id(idim);
                }

                borrow = false;

                if(new_multi_id[idim] < GetLength(IDim))
                {
                    new_multi_id(idim) += GetLength(IDim);
                    borrow = true;
                }
            });

            // shift back down multi-id
            // here, should have new_multi_id[...] >= GetLengths()
            new_multi_id = new_multi_id - GetLengths();
        });

        return new_multi_id;
    }

    // Descriptor restricted to the given dimensions (lengths/strides kept as-is).
    template <index_t... IDims>
    __host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
    {
        static_assert(sizeof...(IDims) <= GetNumOfDimension(),
                      "wrong! too many number of dimensions to be extracted");

        using extract_lengths = decltype(Lengths::Extract(extract_dims...));
        using extract_strides = decltype(Strides::Extract(extract_dims...));

        return ConstantTensorDescriptor<extract_lengths, extract_strides>{};
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto Extract(Sequence<IDims...>)
    {
        return Extract(Number<IDims>{}...);
    }

    // Concatenate another descriptor's dimensions after this one's.
    template <class... Ts>
    __host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor<Ts...>)
    {
        using leaf_tensor = ConstantTensorDescriptor<Ts...>;

        return ConstantTensorDescriptor<decltype(GetLengths().Append(leaf_tensor::GetLengths())),
                                        decltype(GetStrides().Append(leaf_tensor::GetStrides()))>{};
    }

    // Shrink dimension IDim to SliceLen elements; strides are unchanged.
    template <index_t IDim, index_t SliceLen>
    __host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
    {
        using slice_lengths = decltype(Lengths{}.Modify(Number<IDim>{}, Number<SliceLen>{}));

        return ConstantTensorDescriptor<slice_lengths, Strides>{};
    }

    // Split dimension IDim into 1 + sizeof...(FoldIntervals) dimensions; the
    // leading folded length absorbs whatever the given intervals don't cover.
    template <index_t IDim, index_t... FoldIntervals>
    __host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
    {
        constexpr auto fold_intervals = Sequence<FoldIntervals...>{};

        constexpr index_t fold_intervals_product =
            accumulate_on_sequence(fold_intervals, mod_conv::multiplies<index_t>{}, Number<1>{});

        constexpr auto unfold_length = GetLength(Number<IDim>{});
        constexpr auto unfold_stride = GetStride(Number<IDim>{});

        // length of the dimension to be folded needs to be dividable by fold_interval_product,
        // otherwise, folding is invalid
        static_assert(unfold_length % fold_intervals_product == 0,
                      "wrong! length on the dimension to be folded cannot be evenly divided!");

        // folded lengths
        constexpr auto fold_lengths =
            Sequence<unfold_length / fold_intervals_product>{}.Append(fold_intervals);

        // folded strides
        constexpr auto fold_strides =
            Number<unfold_stride>{} *
            reverse_inclusive_scan_sequence(
                fold_intervals.PushBack(Number<1>{}), mod_conv::multiplies<index_t>{}, Number<1>{});

        // left and right
        constexpr auto left  = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{};
        constexpr auto right = typename arithmetic_sequence_gen<IDim + 1, nDim, 1>::SeqType{};

        constexpr auto new_lengths =
            GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right));
        constexpr auto new_strides =
            GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right));

        return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
    }

    // this function unfold dimension [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension
    template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
    __host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
    {
        static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim &&
                          FirstUnfoldDim <= LastUnfoldDim,
                      "wrong! should have FirstUnfoldDim <= LastUnfoldDim!");

        // left and right
        constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::SeqType{};
        constexpr auto middle =
            typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::SeqType{};
        constexpr auto right =
            typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::SeqType{};

        // dimensions to be unfolded need to be continuous
        static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable");

        // unfolded length, stride
        constexpr index_t unfold_length = accumulate_on_sequence(
            GetLengths().Extract(middle), mod_conv::multiplies<index_t>{}, Number<1>{});

        constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});

        // new lengths, strides
        constexpr auto new_lengths = GetLengths()
                                         .Extract(left)
                                         .PushBack(Number<unfold_length>{})
                                         .Append(GetLengths().Extract(right));

        constexpr auto new_strides = GetStrides()
                                         .Extract(left)
                                         .PushBack(Number<unfold_stride>{})
                                         .Append(GetStrides().Extract(right));

        return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
    }

    // Permute dimensions: new dimension i is old dimension MapNew2Old[i].
    template <class MapNew2Old>
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
    {
        return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})),
                                        decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
    }

#if 0 // require sequence_sort, which is not implemented yet
    template <class MapOld2New>
    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
    {
        return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
                                        decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{}
    }
#endif
};

// Descriptor with packed (gap-free, row-major) strides derived from the lengths.
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
{
    using Strides = decltype(calculate_tensor_strides_packed(Lengths{}));
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

// Descriptor with explicitly supplied strides.
template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

// Descriptor whose innermost row is padded up to a multiple of Align.
template <class Lengths, index_t Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
    using Strides = decltype(calculate_tensor_strides_aligned(Lengths{}, Number<Align>{}));
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

// Debug helper: print lengths and strides. Dispatched on ndim with static_if
// because printf needs a format string with a matching, fixed argument count.
template <index_t... Lengths, index_t... Strides>
__host__ __device__ void
print_ConstantTensorDescriptor(const char* s,
                               ConstantTensorDescriptor<Sequence<Lengths...>, Sequence<Strides...>>)
{
    constexpr index_t ndim = sizeof...(Lengths);

    static_assert(ndim > 0 && ndim <= 10, "wrong!");

    static_if<ndim == 1>{}([&](auto) {
        printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 2>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 3>{}([&](auto) {
        printf(
            "%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, ndim, Lengths..., Strides...);
    });

    static_if<ndim == 4>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 5>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 6>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 7>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 8>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 9>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
               "%u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });

    static_if<ndim == 10>{}([&](auto) {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
               "%u %u %u}\n",
               s,
               ndim,
               Lengths...,
               Strides...);
    });
}