#pragma once #include "common.hip.hpp" // this is ugly, only for 2d template __host__ __device__ constexpr auto calculate_default_strides(Sequence) { return Sequence{}; } // this is ugly, only for 3d template __host__ __device__ constexpr auto calculate_default_strides(Sequence) { return Sequence{}; } // this is ugly, only for 4d template __host__ __device__ constexpr auto calculate_default_strides(Sequence) { return Sequence{}; } // this is ugly, only for 6d template __host__ __device__ constexpr auto calculate_default_strides(Sequence) { return Sequence{}; } // this is ugly, only for 8d template __host__ __device__ constexpr auto calculate_default_strides(Sequence) { return Sequence{}; } // this is ugly, only for 8d template __host__ __device__ constexpr auto calculate_default_strides(Sequence) { return Sequence{}; } // this is ugly, only for 2d template __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence, Number) { constexpr index_t L1_align = Align * ((L1 + Align - 1) / Align); return Sequence{}; } // this is ugly, only for 3d template __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence, Number) { constexpr index_t L2_align = Align * ((L2 + Align - 1) / Align); return Sequence{}; } // this is ugly, only for 4d template __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence, Number) { constexpr index_t L3_align = Align * ((L3 + Align - 1) / Align); return Sequence{}; } template struct ConstantTensorDescriptor { using Type = ConstantTensorDescriptor; static constexpr index_t nDim = Lengths::nDim; __host__ __device__ constexpr ConstantTensorDescriptor() { static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent"); } __host__ __device__ static constexpr index_t GetDimension() { return nDim; } __host__ __device__ static constexpr Lengths GetLengths() { return Lengths{}; } __host__ __device__ static constexpr Strides GetStrides() { return Strides{}; } template __host__ __device__ static constexpr index_t GetLength(Number) { return Lengths{}.Get(Number{}); } template __host__ __device__ static constexpr index_t GetStride(Number) { return Strides{}.Get(Number{}); } __host__ __device__ static constexpr index_t GetElementSize() { return accumulate_on_sequence(Lengths{}, mod_conv::multiplies{}, Number<1>{}); } // c++14 doesn't support constexpr lambdas, has to use this trick instead struct GetElementSpace_f { template __host__ __device__ constexpr index_t operator()(IDim idim) const { return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim); } }; template > __host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{}) { index_t element_space_unaligned = static_const_reduce_n{}(GetElementSpace_f{}, mod_conv::plus{}) + 1; return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get()); } template __host__ __device__ static index_t Get1dIndex(Is... is) { static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong"); const auto multi_id = Array(is...); index_t id = 0; static_for<0, nDim, 1>{}([&](auto IDim) { constexpr index_t idim = IDim.Get(); #if DEVICE_BACKEND_HIP id += __mul24(multi_id[idim], GetStride(IDim)); #else id += multi_id[idim] * GetStride(IDim); #endif }); return id; } __host__ __device__ static Array GetMultiIndex(index_t id) { Array multi_id; static_for<0, nDim - 1, 1>{}([&](auto IDim) { constexpr index_t idim = IDim.Get(); multi_id[idim] = id / GetStride(IDim); id -= multi_id[idim] * GetStride(IDim); }); multi_id[nDim - 1] = id / GetStride(Number{}); return multi_id; } __host__ __device__ static constexpr auto Condense() { constexpr auto default_strides = calculate_default_strides(Lengths{}); return ConstantTensorDescriptor{}; } }; template __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths) { using Strides = decltype(calculate_default_strides(Lengths{})); return ConstantTensorDescriptor{}; } template __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides) { return ConstantTensorDescriptor{}; } template __host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number) { using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number{})); return ConstantTensorDescriptor{}; } template __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) { constexpr auto desc = TDesc{}; constexpr index_t ndim = desc.GetDimension(); static_assert(ndim >= 2 && ndim <= 10, "wrong!"); if(ndim == 2) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, desc.GetDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetStride(I0), desc.GetStride(I1)); } else if(ndim == 3) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, desc.GetDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2)); } else if(ndim == 4) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n", s, desc.GetDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)); } else if(ndim == 5) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n", s, desc.GetDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4)); } else if(ndim == 6) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; constexpr auto I5 = Number<5>{}; printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n", s, desc.GetDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetLength(I5), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4), desc.GetStride(I5)); } else if(ndim == 7) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n", s, desc.GetDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetLength(I5), desc.GetLength(I6), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4), desc.GetStride(I5), desc.GetStride(I6)); } else if(ndim == 8) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; constexpr auto I7 = Number<7>{}; printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n", s, desc.GetDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetLength(I5), desc.GetLength(I6), desc.GetLength(I7), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4), desc.GetStride(I5), desc.GetStride(I6), desc.GetStride(I7)); } else if(ndim == 9) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; constexpr auto I7 = Number<7>{}; constexpr auto I8 = Number<8>{}; printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u " "%u}\n", s, desc.GetDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetLength(I5), desc.GetLength(I6), desc.GetLength(I7), desc.GetLength(I8), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4), desc.GetStride(I5), desc.GetStride(I6), desc.GetStride(I7), desc.GetStride(I8)); } else if(ndim == 10) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; constexpr auto I4 = Number<4>{}; constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; constexpr auto I7 = Number<7>{}; constexpr auto I8 = Number<8>{}; constexpr auto I9 = Number<9>{}; printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u " "%u %u %u}\n", s, desc.GetDimension(), desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3), desc.GetLength(I4), desc.GetLength(I5), desc.GetLength(I6), desc.GetLength(I7), desc.GetLength(I8), desc.GetLength(I9), desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3), desc.GetStride(I4), desc.GetStride(I5), desc.GetStride(I6), desc.GetStride(I7), desc.GetStride(I8), desc.GetStride(I9)); } }