#include <cuda.h>          // defines CUDA_VERSION, used below
#include <cuda_runtime.h>

static const unsigned WARP_SIZE = 32;

// The maximum number of threads in a block
static const unsigned MAX_BLOCK_SIZE = 512U;

template <typename In, typename Out>
struct ScalarConvert {
  static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; }
};

// Number of threads in a block for an input of size nElem, capped at MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
  int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}

// Returns the index of the most significant 1 bit in `val`.
__device__ __forceinline__ int getMSB(int val) {
  return 31 - __clz(val);
}

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
                                           unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}

// Sum across all threads within a warp
template <typename T>
static __device__ __forceinline__ T warpSum(T val) {
#if __CUDA_ARCH__ >= 300
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
    val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
  }
#else
  __shared__ T values[MAX_BLOCK_SIZE];
  values[threadIdx.x] = val;
  __threadfence_block();
  const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
  for (int i = 1; i < WARP_SIZE; i++) {
    val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
  }
#endif
  return val;
}

template <typename DType, typename Acctype>
struct Float2 {
  Acctype v1, v2;
  __device__ Float2() {}
  __device__ Float2(DType v1, DType v2)
    : v1(ScalarConvert<DType, Acctype>::to(v1)),
      v2(ScalarConvert<DType, Acctype>::to(v2)) {}
  __device__ Float2(DType v)
    : v1(ScalarConvert<DType, Acctype>::to(v)),
      v2(ScalarConvert<DType, Acctype>::to(v)) {}
  __device__ Float2(int v)
    : v1(ScalarConvert<int, Acctype>::to(v)),
      v2(ScalarConvert<int, Acctype>::to(v)) {}
  __device__ Float2& operator+=(const Float2& a) {
    v1 += a.v1;
    v2 += a.v2;
    return *this;
  }
};

template <typename DType, typename Acctype>
static __device__ __forceinline__ Float2<DType, Acctype> warpSum(Float2<DType, Acctype> value) {
  value.v1 = warpSum(value.v1);
  value.v2 = warpSum(value.v2);
  return value;
}

// Block-wide sum of op(b, i, k, x) over x in [0, D)
template <typename T, typename Op>
__device__ T reduceD(Op op, int b, int i, int k, int D) {
  T sum = 0;
  for (int x = threadIdx.x; x < D; x += blockDim.x) {
    sum += op(b, i, k, x);
  }

  // sum over the threads within each warp
  sum = warpSum(sum);

  // 'transpose', and reduce within a warp again
  __shared__ T shared[32];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    if (threadIdx.x / WARP_SIZE < 32) {
      shared[threadIdx.x / WARP_SIZE] = sum;
    }
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T) 0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();

  // Everyone picks it up; the result is broadcast to the whole block
  return shared[0];
}

// Block-wide sum of op(b, x, k, d) over x in [0, N)
template <typename T, typename Op>
__device__ T reduceN(Op op, int b, int k, int d, int N) {
  T sum = 0;
  for (int x = threadIdx.x; x < N; x += blockDim.x) {
    sum += op(b, x, k, d);
  }

  // sum over the threads within each warp
  sum = warpSum(sum);

  // 'transpose', and reduce within a warp again
  __shared__ T shared[32];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    if (threadIdx.x / WARP_SIZE < 32) {
      shared[threadIdx.x / WARP_SIZE] = sum;
    }
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T) 0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();

  // Everyone picks it up; the result is broadcast to the whole block
  return shared[0];
}
// Block-wide sum of op(b, i, x, d) over x in [0, K)
template <typename T, typename Op>
__device__ T reduceK(Op op, int b, int i, int d, int K) {
  T sum = 0;
  for (int x = threadIdx.x; x < K; x += blockDim.x) {
    sum += op(b, i, x, d);
  }

  // sum over the threads within each warp
  sum = warpSum(sum);

  // 'transpose', and reduce within a warp again
  __shared__ T shared[32];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    if (threadIdx.x / WARP_SIZE < 32) {
      shared[threadIdx.x / WARP_SIZE] = sum;
    }
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T) 0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();

  // Everyone picks it up; the result is broadcast to the whole block
  return shared[0];
}

// Block-wide sum of op(batch, x, k, d) over batch in [0, B) and x in [0, N)
template <typename T, typename Op>
__device__ T reduceBN(Op op, int k, int d, int B, int N) {
  T sum = 0;
  for (int batch = 0; batch < B; ++batch) {
    for (int x = threadIdx.x; x < N; x += blockDim.x) {
      sum += op(batch, x, k, d);
    }
  }

  // sum over the threads within each warp
  sum = warpSum(sum);

  // 'transpose', and reduce within a warp again
  __shared__ T shared[32];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    if (threadIdx.x / WARP_SIZE < 32) {
      shared[threadIdx.x / WARP_SIZE] = sum;
    }
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T) 0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();

  // Everyone picks it up; the result is broadcast to the whole block
  return shared[0];
}
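// ---------------------------------------------------------------------------
// Usage sketch (not part of the original header): how one of these block-wide
// reductions might be driven from a kernel. The functor `SumNOp`, the
// contiguous [B, N, K, D] layout, and the launch configuration are assumptions
// made for illustration only; the reduce* helpers above only require a functor
// callable with the four integer indices shown.
// ---------------------------------------------------------------------------
template <typename T>
struct SumNOp {
  const T* data;  // assumed contiguous [B, N, K, D] tensor
  int N, K, D;
  __device__ T operator()(int b, int n, int k, int d) const {
    return data[((b * N + n) * K + k) * D + d];
  }
};

// One block per (b, k, d) triple: gridDim = (K * D, B), blockDim = getNumThreads(N).
template <typename T>
__global__ void sumOverN_kernel(const T* in, T* out, int N, int K, int D) {
  const int b = blockIdx.y;
  const int k = blockIdx.x / D;
  const int d = blockIdx.x % D;
  SumNOp<T> op = { in, N, K, D };
  T sum = reduceN<T>(op, b, k, d, N);  // every thread in the block gets the sum
  if (threadIdx.x == 0) {
    out[(b * K + k) * D + d] = sum;    // e.g. divide by N here for a mean
  }
}
// Possible launch:
//   sumOverN_kernel<<<dim3(K * D, B), getNumThreads(N)>>>(in, out, N, K, D);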