#pragma once
#include "common.hip.hpp"

// Default strides for a tensor with the given lengths: packed, row-major.
// The last dimension has stride 1; every other dimension's stride is the
// product of the lengths of all dimensions to its right.
// One overload per rank (2d .. 8d), so every rank accepted by
// ConstantTensorDescriptor can be built through make_ConstantTensorDescriptor.

// 2d
template <unsigned L0, unsigned L1>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
{
    return Sequence<L1, 1>{};
}

// 3d
template <unsigned L0, unsigned L1, unsigned L2>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2>)
{
    return Sequence<L1 * L2, L2, 1>{};
}

// 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
    return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}

// 5d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4>)
{
    return Sequence<L1 * L2 * L3 * L4, L2 * L3 * L4, L3 * L4, L4, 1>{};
}

// 6d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
{
    return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
}

// 7d
template <unsigned L0,
          unsigned L1,
          unsigned L2,
          unsigned L3,
          unsigned L4,
          unsigned L5,
          unsigned L6>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6>)
{
    return Sequence<L1 * L2 * L3 * L4 * L5 * L6,
                    L2 * L3 * L4 * L5 * L6,
                    L3 * L4 * L5 * L6,
                    L4 * L5 * L6,
                    L5 * L6,
                    L6,
                    1>{};
}

// 8d
template <unsigned L0,
          unsigned L1,
          unsigned L2,
          unsigned L3,
          unsigned L4,
          unsigned L5,
          unsigned L6,
          unsigned L7>
__host__ __device__ constexpr auto
    calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
{
    return Sequence<L1 * L2 * L3 * L4 * L5 * L6 * L7,
                    L2 * L3 * L4 * L5 * L6 * L7,
                    L3 * L4 * L5 * L6 * L7,
                    L4 * L5 * L6 * L7,
                    L5 * L6 * L7,
                    L6 * L7,
                    L7,
                    1>{};
}

// Row-major strides where the innermost length is rounded up to a multiple of
// Align before the outer strides are formed (padded rows). Only 2d and 4d.

// 2d
template <unsigned L0, unsigned L1, unsigned Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
                                                                     Number<Align>)
{
    // round L1 up to the next multiple of Align
    constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align);
    return Sequence<L1_align, 1>{};
}

// 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
                                                                     Number<Align>)
{
    // round L3 up to the next multiple of Align
    constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align);
    return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{};
}

// Compile-time description of an nDim-dimensional tensor layout.
// Lengths and Strides are Sequence types of equal rank; every query below is
// answered from the types alone (the descriptor carries no runtime state).
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
    static constexpr unsigned nDim = Lengths::nDim;
    using NDimConstant             = Number<nDim>;

    __host__ __device__ constexpr ConstantTensorDescriptor()
    {
        static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
    }

    __host__ __device__ constexpr unsigned GetDimension() const { return nDim; }

    __host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }

    __host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }

    template <unsigned I>
    __host__ __device__ constexpr unsigned GetLength(Number<I>) const
    {
        return Lengths{}.Get(Number<I>{});
    }

    template <unsigned I>
    __host__ __device__ constexpr unsigned GetStride(Number<I>) const
    {
        return Strides{}.Get(Number<I>{});
    }

    // Number of logical elements: the product of all lengths.
    __host__ __device__ constexpr unsigned GetElementSize() const
    {
        static_assert(nDim >= 2 && nDim <= 8, "nDim");

        return GetLengthProductFrom(Number<0>{});
    }

    // Storage footprint in elements: the offset of the last element,
    // sum_i (length_i - 1) * stride_i, plus align.Get() (1 by default, which
    // makes the result "last offset + 1").
    template <class Align = Number<1>>
    __host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
    {
        static_assert(nDim >= 2 && nDim <= 8, "nDim");

        constexpr unsigned align_size = align.Get();

        return GetMaxOffsetFrom(Number<0>{}) + align_size;
    }

    // Linear offset of a multi-index: sum_i index_i * stride_i.
    // One overload per rank; calling an overload whose arity does not match
    // nDim fails the static_assert.

    // 2d
    __host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1) const
    {
        static_assert(nDim == 2, "nDim is not 2");
        return i0 * GetStride(Number<0>{}) + i1 * GetStride(Number<1>{});
    }

    // 3d
    __host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1, unsigned i2) const
    {
        static_assert(nDim == 3, "nDim is not 3");
        return i0 * GetStride(Number<0>{}) + i1 * GetStride(Number<1>{}) +
               i2 * GetStride(Number<2>{});
    }

    // 4d
    __host__ __device__ unsigned
    Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3) const
    {
        static_assert(nDim == 4, "nDim is not 4");
        return i0 * GetStride(Number<0>{}) + i1 * GetStride(Number<1>{}) +
               i2 * GetStride(Number<2>{}) + i3 * GetStride(Number<3>{});
    }

    // 5d
    __host__ __device__ unsigned
    Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4) const
    {
        static_assert(nDim == 5, "nDim is not 5");
        return i0 * GetStride(Number<0>{}) + i1 * GetStride(Number<1>{}) +
               i2 * GetStride(Number<2>{}) + i3 * GetStride(Number<3>{}) +
               i4 * GetStride(Number<4>{});
    }

    // 6d
    __host__ __device__ unsigned Get1dIndex(
        unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4, unsigned i5) const
    {
        static_assert(nDim == 6, "nDim is not 6");
        return i0 * GetStride(Number<0>{}) + i1 * GetStride(Number<1>{}) +
               i2 * GetStride(Number<2>{}) + i3 * GetStride(Number<3>{}) +
               i4 * GetStride(Number<4>{}) + i5 * GetStride(Number<5>{});
    }

    // 7d
    __host__ __device__ unsigned Get1dIndex(unsigned i0,
                                            unsigned i1,
                                            unsigned i2,
                                            unsigned i3,
                                            unsigned i4,
                                            unsigned i5,
                                            unsigned i6) const
    {
        static_assert(nDim == 7, "nDim is not 7");
        return i0 * GetStride(Number<0>{}) + i1 * GetStride(Number<1>{}) +
               i2 * GetStride(Number<2>{}) + i3 * GetStride(Number<3>{}) +
               i4 * GetStride(Number<4>{}) + i5 * GetStride(Number<5>{}) +
               i6 * GetStride(Number<6>{});
    }

    // 8d
    __host__ __device__ unsigned Get1dIndex(unsigned i0,
                                            unsigned i1,
                                            unsigned i2,
                                            unsigned i3,
                                            unsigned i4,
                                            unsigned i5,
                                            unsigned i6,
                                            unsigned i7) const
    {
        static_assert(nDim == 8, "nDim is not 8");
        return i0 * GetStride(Number<0>{}) + i1 * GetStride(Number<1>{}) +
               i2 * GetStride(Number<2>{}) + i3 * GetStride(Number<3>{}) +
               i4 * GetStride(Number<4>{}) + i5 * GetStride(Number<5>{}) +
               i6 * GetStride(Number<6>{}) + i7 * GetStride(Number<7>{});
    }

    // Descriptor with the same lengths but packed row-major strides.
    // Uses the same Strides type as make_ConstantTensorDescriptor(Lengths)
    // (decltype of the call, not of a const local, so Strides is not
    // const-qualified).
    __host__ __device__ constexpr auto Condense() const
    {
        using DefaultStrides = decltype(calculate_default_strides(Lengths{}));
        return ConstantTensorDescriptor<Lengths, DefaultStrides>{};
    }

    private:
    // Product of the lengths of dimensions [IDim, nDim).
    // Recursion only instantiates in-range dimension indices.
    template <unsigned IDim>
    __host__ __device__ constexpr unsigned GetLengthProductFrom(Number<IDim>) const
    {
        return GetLength(Number<IDim>{}) * GetLengthProductFrom(Number<IDim + 1>{});
    }

    // Terminator (empty product). For an argument of exactly Number<nDim>,
    // overload resolution prefers this non-template over the template above.
    __host__ __device__ constexpr unsigned GetLengthProductFrom(NDimConstant) const { return 1; }

    // Sum of (length_i - 1) * stride_i over dimensions [IDim, nDim).
    template <unsigned IDim>
    __host__ __device__ constexpr unsigned GetMaxOffsetFrom(Number<IDim>) const
    {
        return (GetLength(Number<IDim>{}) - 1) * GetStride(Number<IDim>{}) +
               GetMaxOffsetFrom(Number<IDim + 1>{});
    }

    // Terminator (empty sum); preferred for Number<nDim> as above.
    __host__ __device__ constexpr unsigned GetMaxOffsetFrom(NDimConstant) const { return 0; }
};

// Descriptor with packed row-major strides derived from the lengths.
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths)
{
    using Strides = decltype(calculate_default_strides(Lengths{}));
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

// Descriptor with explicitly given strides.
template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

// Descriptor with row-major strides whose innermost length is padded to a
// multiple of Align (2d and 4d only; see calculate_default_strides_aligned).
template <class Lengths, unsigned Align>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
{
    using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{}));
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

// Debug helper: printf the rank, lengths and strides of a descriptor,
// prefixed by the caller-supplied label s. Supports ranks 2 .. 8.
template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
    constexpr auto desc    = TDesc{};
    constexpr unsigned ndim = desc.GetDimension();

    static_assert(ndim >= 2 && ndim <= 8, "wrong!");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    if(ndim == 2)
    {
        printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n",
               s,
               desc.GetDimension(),
               desc.GetLength(I0),
               desc.GetLength(I1),
               desc.GetStride(I0),
               desc.GetStride(I1));
    }
    else if(ndim == 3)
    {
        // previously missing: 3d descriptors passed the static_assert but
        // printed nothing
        printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n",
               s,
               desc.GetDimension(),
               desc.GetLength(I0),
               desc.GetLength(I1),
               desc.GetLength(I2),
               desc.GetStride(I0),
               desc.GetStride(I1),
               desc.GetStride(I2));
    }
    else if(ndim == 4)
    {
        printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
               s,
               desc.GetDimension(),
               desc.GetLength(I0),
               desc.GetLength(I1),
               desc.GetLength(I2),
               desc.GetLength(I3),
               desc.GetStride(I0),
               desc.GetStride(I1),
               desc.GetStride(I2),
               desc.GetStride(I3));
    }
    else if(ndim == 5)
    {
        printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
               s,
               desc.GetDimension(),
               desc.GetLength(I0),
               desc.GetLength(I1),
               desc.GetLength(I2),
               desc.GetLength(I3),
               desc.GetLength(I4),
               desc.GetStride(I0),
               desc.GetStride(I1),
               desc.GetStride(I2),
               desc.GetStride(I3),
               desc.GetStride(I4));
    }
    else if(ndim == 6)
    {
        printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
               s,
               desc.GetDimension(),
               desc.GetLength(I0),
               desc.GetLength(I1),
               desc.GetLength(I2),
               desc.GetLength(I3),
               desc.GetLength(I4),
               desc.GetLength(I5),
               desc.GetStride(I0),
               desc.GetStride(I1),
               desc.GetStride(I2),
               desc.GetStride(I3),
               desc.GetStride(I4),
               desc.GetStride(I5));
    }
    else if(ndim == 7)
    {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
               s,
               desc.GetDimension(),
               desc.GetLength(I0),
               desc.GetLength(I1),
               desc.GetLength(I2),
               desc.GetLength(I3),
               desc.GetLength(I4),
               desc.GetLength(I5),
               desc.GetLength(I6),
               desc.GetStride(I0),
               desc.GetStride(I1),
               desc.GetStride(I2),
               desc.GetStride(I3),
               desc.GetStride(I4),
               desc.GetStride(I5),
               desc.GetStride(I6));
    }
    else if(ndim == 8)
    {
        printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
               s,
               desc.GetDimension(),
               desc.GetLength(I0),
               desc.GetLength(I1),
               desc.GetLength(I2),
               desc.GetLength(I3),
               desc.GetLength(I4),
               desc.GetLength(I5),
               desc.GetLength(I6),
               desc.GetLength(I7),
               desc.GetStride(I0),
               desc.GetStride(I1),
               desc.GetStride(I2),
               desc.GetStride(I3),
               desc.GetStride(I4),
               desc.GetStride(I5),
               desc.GetStride(I6),
               desc.GetStride(I7));
    }
}