#pragma once
#include <cstdio>

#include "common.cuh"

// Value N of type T lifted into a type, so that it can be passed to functions
// as a tag argument and recovered as a template parameter.
template <class T, T N>
struct Constant
{
    const T mValue = N;
};

// Unsigned compile-time constant; used below as a dimension index tag.
template <unsigned N>
using Number = Constant<unsigned, N>;

// Compile-time list of unsigned values.
//   nDim  - number of values in the list
//   mData - the values, in declaration order
template <unsigned... Is>
struct Sequence
{
    static constexpr unsigned nDim = sizeof...(Is);

    const unsigned mData[nDim] = {Is...};

    // Returns the I-th value of the sequence. I is taken from the tag type.
    template <unsigned I>
    __host__ __device__ constexpr unsigned Get(Number<I>) const
    {
        // Reject bad indices at compile time instead of reading past mData.
        static_assert(I < nDim, "Sequence index out of range");
        return mData[I];
    }
};

// Tensor descriptor whose per-dimension lengths and strides are both encoded
// in the type (Lengths and Strides are Sequence instantiations of equal nDim).
// Objects of this type carry no data members.
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
    static constexpr unsigned nDim = Lengths::nDim;
    using NDimConstant             = Number<nDim>;

    __host__ __device__ constexpr ConstantTensorDescriptor()
    {
        static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
    }

    // Number of dimensions.
    __host__ __device__ constexpr unsigned GetDimension() const { return nDim; }

    // Default-constructed Lengths sequence.
    __host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }

    // Default-constructed Strides sequence.
    __host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }

    // Length of dimension I.
    template <unsigned I>
    __host__ __device__ constexpr unsigned GetLength(Number<I>) const
    {
        return Lengths{}.Get(Number<I>{});
    }

    // Stride of dimension I.
    template <unsigned I>
    __host__ __device__ constexpr unsigned GetStride(Number<I>) const
    {
        return Strides{}.Get(Number<I>{});
    }

    // Product of the four lengths.
    // Limitation: hard-coded for 4 dimensions.
    __host__ __device__ constexpr unsigned GetElementSize() const
    {
        static_assert(nDim == 4, "nDim is not 4");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3);
    }

    // One past the largest linear offset reachable with in-range indices:
    // sum over dimensions of (length - 1) * stride, plus 1.
    // Limitation: hard-coded for 4 dimensions.
    __host__ __device__ constexpr unsigned GetElementSpace() const
    {
        static_assert(nDim == 4, "nDim is not 4");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
               (GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) + 1;
    }

    // Linear offset of the element at index (n, c, h, w), where n, c, h, w are
    // the indices along dimensions 0, 1, 2, 3 respectively. No range checking
    // is performed on the indices.
    // Limitation: hard-coded for 4 dimensions.
    __host__ __device__ constexpr unsigned
    Get1dIndex(unsigned n, unsigned c, unsigned h, unsigned w) const
    {
        static_assert(nDim == 4, "nDim is not 4");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        return n * GetStride(I0) + c * GetStride(I1) + h * GetStride(I2) + w * GetStride(I3);
    }
};

// Strides for a densely packed 4-d tensor whose last dimension is contiguous:
// {L1*L2*L3, L2*L3, L3, 1}.
// Limitation: hard-coded for 4 dimensions.
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
    return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}

// Builds a descriptor from lengths only; strides come from
// calculate_default_strides, so this overload only works for 4-d lengths.
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths)
{
    using Strides = decltype(calculate_default_strides(Lengths{}));
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

// Builds a descriptor from explicit lengths and strides.
template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
    return ConstantTensorDescriptor<Lengths, Strides>{};
}

// Derives a packed 4-d descriptor from an input and a weight descriptor.
// Resulting lengths:
//   {in[0], wei[0], in[2] - wei[2] + 1, in[3] - wei[3] + 1}
// Requires in[1] == wei[1].
// Limitation: hard-coded for 4 dimensions.
template <class InDesc, class WeiDesc>
__host__ __device__ constexpr auto get_output_4d_tensor_descriptor(InDesc, WeiDesc)
{
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    static_assert(in_desc.GetDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetDimension() == 4, "weight nDim is not 4");
    static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                  "input & weight dimension not consistent");

    constexpr auto N  = in_desc.GetLength(I0);
    constexpr auto HI = in_desc.GetLength(I2);
    constexpr auto WI = in_desc.GetLength(I3);

    constexpr auto K = wei_desc.GetLength(I0);
    constexpr auto S = wei_desc.GetLength(I2);
    constexpr auto R = wei_desc.GetLength(I3);

    // All quantities are unsigned: without this check a weight window larger
    // than the input would silently wrap around to a huge output length.
    static_assert(HI >= S && WI >= R, "weight window is larger than input");

    constexpr auto HO = HI - S + 1;
    constexpr auto WO = WI - R + 1;

    return make_ConstantTensorDescriptor(Sequence<N, K, HO, WO>{});
}

// Prints the label s followed by the dimension count, the four lengths and
// the four strides of TDesc, on one line.
// Limitation: hard-coded for 4 dimensions.
template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
    constexpr auto desc = TDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    static_assert(desc.GetDimension() == 4, "dim is not 4");

    printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
           s,
           desc.GetDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetLength(I2),
           desc.GetLength(I3),
           desc.GetStride(I0),
           desc.GetStride(I1),
           desc.GetStride(I2),
           desc.GetStride(I3));
}