#include <vector>
#include <ATen/ATen.h>

#include "common.h"
#include "device_tensor.h"

namespace {

// Shape conventions used throughout this file:
//   X  : (B, N, D) input features
//   C  : (K, D)    codewords
//   A  : (B, N, K) assignment weights
//   S  : (K)       smoothing (scale) factors
//   E  : (B, K, D) aggregated encodings

// Element-wise term of the aggregation: A[b][i][k] * (X[b][i][d] - C[k][d]).
template <typename DType, typename Acctype>
struct AggOp {
  __device__ AggOp(DeviceTensor<DType, 3> a,
                   DeviceTensor<DType, 3> x,
                   DeviceTensor<DType, 2> c) : A(a), X(x), C(c) {}
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    return ScalarConvert<DType, Acctype>::to(A[b][i][k] * (X[b][i][d] - C[k][d]));
  }
  DeviceTensor<DType, 3> A;
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
};

// Element-wise term of the gradient w.r.t. A: G[b][k][d] * (X[b][i][d] - C[k][d]).
template <typename DType, typename Acctype>
struct AggBackOp {
  __device__ AggBackOp(DeviceTensor<DType, 3> g,
                       DeviceTensor<DType, 3> x,
                       DeviceTensor<DType, 2> c) : G(g), X(x), C(c) {}
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    return ScalarConvert<DType, Acctype>::to(G[b][k][d] * (X[b][i][d] - C[k][d]));
  }
  DeviceTensor<DType, 3> G;
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
};

// Squared residual (X[b][i][d] - C[k][d])^2; the caller sums it over d.
template <typename DType, typename Acctype>
struct SL2Op {
  __device__ SL2Op(DeviceTensor<DType, 3> x,
                   DeviceTensor<DType, 2> c) : X(x), C(c) {}
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    DType r = X[b][i][d] - C[k][d];
    return ScalarConvert<DType, Acctype>::to(r * r);
  }
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
};

// Element-wise term of the scaled-L2 gradient w.r.t. X (and, negated, w.r.t. C).
template <typename DType, typename Acctype>
struct SL2GradXOp {
  __device__ SL2GradXOp(DeviceTensor<DType, 3> gsl,
                        DeviceTensor<DType, 3> x,
                        DeviceTensor<DType, 2> c,
                        DeviceTensor<DType, 1> s)
    : GSL(gsl), X(x), C(c), S(s) {}
  __device__ __forceinline__ Acctype operator()(int b, int i, int k, int d) {
    return ScalarConvert<DType, Acctype>::to(
      2 * S[k] * GSL[b][i][k] * (X[b][i][d] - C[k][d]));
  }
  DeviceTensor<DType, 3> GSL;
  DeviceTensor<DType, 3> X;
  DeviceTensor<DType, 2> C;
  DeviceTensor<DType, 1> S;
};

// Block-wide reduction of op(b, x, k, d) over x in [0, N).
template <typename T, typename Op>
__device__ T reduceN(Op op, int b, int k, int d, int N) {
  T sum = 0;
  for (int x = threadIdx.x; x < N; x += blockDim.x) {
    sum += op(b, x, k, d);
  }
  // sum over NumThreads within a warp
  sum = warpSum(sum);
  // 'transpose', and reduce within warp again
  __shared__ T shared[32];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    if (threadIdx.x / WARP_SIZE < 32) {
      shared[threadIdx.x / WARP_SIZE] = sum;
    }
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T) 0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();
  // Everyone picks it up, should be broadcast into the whole gradInput
  return shared[0];
}

// Block-wide reduction of op(b, i, k, x) over x in [0, D).
template <typename T, typename Op>
__device__ T reduceD(Op op, int b, int i, int k, int D) {
  T sum = 0;
  for (int x = threadIdx.x; x < D; x += blockDim.x) {
    sum += op(b, i, k, x);
  }
  // sum over NumThreads within a warp
  sum = warpSum(sum);
  // 'transpose', and reduce within warp again
  __shared__ T shared[32];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    if (threadIdx.x / WARP_SIZE < 32) {
      shared[threadIdx.x / WARP_SIZE] = sum;
    }
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T) 0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();
  // Everyone picks it up, should be broadcast into the whole gradInput
  return shared[0];
}
// Block-wide reduction of op(b, i, x, d) over x in [0, K).
template <typename T, typename Op>
__device__ T reduceK(Op op, int b, int i, int d, int K) {
  T sum = 0;
  for (int x = threadIdx.x; x < K; x += blockDim.x) {
    sum += op(b, i, x, d);
  }
  // sum over NumThreads within a warp
  sum = warpSum(sum);
  // 'transpose', and reduce within warp again
  __shared__ T shared[32];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    if (threadIdx.x / WARP_SIZE < 32) {
      shared[threadIdx.x / WARP_SIZE] = sum;
    }
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T) 0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();
  // Everyone picks it up, should be broadcast into the whole gradInput
  return shared[0];
}

// Block-wide reduction of op(batch, x, k, d) over all batches and x in [0, N).
template <typename T, typename Op>
__device__ T reduceBN(Op op, int k, int d, int B, int N) {
  T sum = 0;
  for (int batch = 0; batch < B; ++batch) {
    for (int x = threadIdx.x; x < N; x += blockDim.x) {
      sum += op(batch, x, k, d);
    }
  }
  // sum over NumThreads within a warp
  sum = warpSum(sum);
  // 'transpose', and reduce within warp again
  __shared__ T shared[32];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    if (threadIdx.x / WARP_SIZE < 32) {
      shared[threadIdx.x / WARP_SIZE] = sum;
    }
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T) 0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();
  // Everyone picks it up, should be broadcast into the whole gradInput
  return shared[0];
}

// E[b][k][d] = sum_i A[b][i][k] * (X[b][i][d] - C[k][d])
template <typename DType, typename Acctype>
__global__ void Aggregate_Forward_kernel (
    DeviceTensor<DType, 3> E,
    DeviceTensor<DType, 3> A,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C) {
  /* declarations of the variables */
  int b, k, d, N;
  /* Get the index and channels */
  b = blockIdx.z;
  d = blockIdx.x;
  k = blockIdx.y;
  N = X.getSize(1);
  /* main operation */
  AggOp<DType, Acctype> g(A, X, C);
  E[b][k][d] = reduceN<Acctype>(g, b, k, d, N);
}

// GA[b][i][k] = sum_d GE[b][k][d] * (X[b][i][d] - C[k][d])
template <typename DType, typename Acctype>
__global__ void Aggregate_Backward_kernel (
    DeviceTensor<DType, 3> GA,
    DeviceTensor<DType, 3> GE,
    DeviceTensor<DType, 3> A,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C) {
  /* declarations of the variables */
  int b, k, i, D;
  /* Get the index and channels */
  b = blockIdx.z;
  i = blockIdx.y;
  k = blockIdx.x;
  D = GE.getSize(2);
  /* main operation */
  AggBackOp<DType, Acctype> g(GE, X, C);
  GA[b][i][k] = reduceD<Acctype>(g, b, i, k, D);
}

// SL[b][i][k] = S[k] * ||X[b][i] - C[k]||^2
template <typename DType, typename Acctype>
__global__ void ScaledL2_Forward_kernel (
    DeviceTensor<DType, 3> SL,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 1> S) {
  /* declarations of the variables */
  int b, k, i, D;
  /* Get the index and channels */
  b = blockIdx.z;
  k = blockIdx.x;
  i = blockIdx.y;
  D = X.getSize(2);
  /* main operation */
  SL2Op<DType, Acctype> g(X, C);
  SL[b][i][k] = S[k] * reduceD<Acctype>(g, b, i, k, D);
}

// GX[b][i][d] = sum_k 2 * S[k] * GSL[b][i][k] * (X[b][i][d] - C[k][d])
template <typename DType, typename Acctype>
__global__ void ScaledL2_GradX_kernel (
    DeviceTensor<DType, 3> GSL,
    DeviceTensor<DType, 3> GX,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 1> S) {
  /* declarations of the variables */
  int b, d, i, K;
  /* Get the index and channels */
  b = blockIdx.z;
  d = blockIdx.x;
  i = blockIdx.y;
  K = C.getSize(0);
  /* main operation */
  SL2GradXOp<DType, Acctype> g(GSL, X, C, S);
  GX[b][i][d] = reduceK<Acctype>(g, b, i, d, K);
}

// GC[k][d] = - sum_{b,i} 2 * S[k] * GSL[b][i][k] * (X[b][i][d] - C[k][d])
template <typename DType, typename Acctype>
__global__ void ScaledL2_GradC_kernel (
    DeviceTensor<DType, 3> GSL,
    DeviceTensor<DType, 2> GC,
    DeviceTensor<DType, 3> X,
    DeviceTensor<DType, 2> C,
    DeviceTensor<DType, 1> S) {
  /* declarations of the variables */
  int k, d, B, N;
  /* Get the index and channels */
  d = blockIdx.x;
  k = blockIdx.y;
  B = X.getSize(0);
  N = X.getSize(1);
  /* main operation */
  SL2GradXOp<DType, Acctype> g(GSL, X, C, S);
  GC[k][d] = - reduceBN<Acctype>(g, k, d, B, N);
}

} // namespace

at::Tensor Aggregate_Forward_CUDA(
    const at::Tensor A_,
    const at::Tensor X_,
    const at::Tensor C_) {
  /* Device tensors */
  auto E_ = A_.type().tensor({A_.size(0), C_.size(0), C_.size(1)}).zero_();
  cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
  // B, K, D
  dim3 blocks(C_.size(1), C_.size(0), X_.size(0));
  dim3 threads(getNumThreads(X_.size(1)));
  AT_DISPATCH_FLOATING_TYPES(A_.type(), "Aggregate_Forward_CUDA", ([&] {
    DeviceTensor<scalar_t, 3> E = devicetensor<scalar_t, 3>(E_);
    DeviceTensor<scalar_t, 3> A = devicetensor<scalar_t, 3>(A_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_);
    /* kernel function */
    Aggregate_Forward_kernel<scalar_t, scalar_t>
      <<<blocks, threads, 0, stream>>>(E, A, X, C);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return E_;
}
std::vector<at::Tensor> Aggregate_Backward_CUDA(
    const at::Tensor GE_,
    const at::Tensor A_,
    const at::Tensor X_,
    const at::Tensor C_) {
  auto gradA_ = at::zeros_like(A_);
  // gradX and gradC have closed forms in terms of ATen ops
  auto gradX_ = at::bmm(A_, GE_);
  auto gradC_ = (-GE_ * A_.sum(1).unsqueeze(2)).sum(0);
  cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
  // B, K, D
  dim3 blocks(C_.size(0), X_.size(1), X_.size(0));
  dim3 threads(getNumThreads(C_.size(1)));
  AT_DISPATCH_FLOATING_TYPES(A_.type(), "Aggregate_Backward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> GA = devicetensor<scalar_t, 3>(gradA_);
    DeviceTensor<scalar_t, 3> GE = devicetensor<scalar_t, 3>(GE_);
    DeviceTensor<scalar_t, 3> A = devicetensor<scalar_t, 3>(A_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_);
    Aggregate_Backward_kernel<scalar_t, scalar_t>
      <<<blocks, threads, 0, stream>>>(GA, GE, A, X, C);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return {gradA_, gradX_, gradC_};
}

at::Tensor ScaledL2_Forward_CUDA(
    const at::Tensor X_,
    const at::Tensor C_,
    const at::Tensor S_) {
  auto SL_ = X_.type().tensor({X_.size(0), X_.size(1), C_.size(0)}).zero_();
  cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
  dim3 blocks(C_.size(0), X_.size(1), X_.size(0));
  dim3 threads(getNumThreads(C_.size(1)));
  AT_DISPATCH_FLOATING_TYPES(X_.type(), "ScaledL2_Forward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> SL = devicetensor<scalar_t, 3>(SL_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_);
    DeviceTensor<scalar_t, 1> S = devicetensor<scalar_t, 1>(S_);
    /* kernel function */
    ScaledL2_Forward_kernel<scalar_t, scalar_t>
      <<<blocks, threads, 0, stream>>>(SL, X, C, S);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return SL_;
}

std::vector<at::Tensor> ScaledL2_Backward_CUDA(
    const at::Tensor GSL_,
    const at::Tensor X_,
    const at::Tensor C_,
    const at::Tensor S_,
    const at::Tensor SL_) {
  auto GX_ = at::zeros_like(X_);
  auto GC_ = at::zeros_like(C_);
  /* kernel function */
  cudaStream_t stream = at::globalContext().getCurrentCUDAStream();
  dim3 blocks1(X_.size(2), X_.size(1), X_.size(0));
  dim3 threads1(getNumThreads(C_.size(0)));
  dim3 blocks2(C_.size(1), C_.size(0));
  dim3 threads2(getNumThreads(X_.size(1)));
  //std::vector<int64_t> size{1, 1, K};
  //auto GS_ = GSL_ * (SL_ / at::_unsafe_view(S_, size))
  // gradient w.r.t. the scale S, reduced over the batch and spatial dimensions
  auto GS_ = (GSL_ * (SL_ / S_.view({1, 1, C_.size(0)}))).sum(0).sum(0);
  AT_DISPATCH_FLOATING_TYPES(X_.type(), "ScaledL2_Backward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> GSL = devicetensor<scalar_t, 3>(GSL_);
    DeviceTensor<scalar_t, 3> GX = devicetensor<scalar_t, 3>(GX_);
    DeviceTensor<scalar_t, 2> GC = devicetensor<scalar_t, 2>(GC_);
    DeviceTensor<scalar_t, 3> X = devicetensor<scalar_t, 3>(X_);
    DeviceTensor<scalar_t, 2> C = devicetensor<scalar_t, 2>(C_);
    DeviceTensor<scalar_t, 1> S = devicetensor<scalar_t, 1>(S_);
    ScaledL2_GradX_kernel<scalar_t, scalar_t>
      <<<blocks1, threads1, 0, stream>>>(GSL, GX, X, C, S);
    AT_ASSERT(cudaGetLastError() == cudaSuccess);
    ScaledL2_GradC_kernel<scalar_t, scalar_t>
      <<<blocks2, threads2, 0, stream>>>(GSL, GC, X, C, S);
    AT_ASSERT(cudaGetLastError() == cudaSuccess);
  }));
  return {GX_, GC_, GS_};
}
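
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original file): how the exported
// host functions above are typically chained for an Encoding-style layer,
// assuming X: (B, N, D) features, C: (K, D) codewords and S: (K) smoothing
// factors (usually negative, so that closer codewords receive larger weights).
// The helper name `encoding_forward_example` is hypothetical and used only
// for demonstration.
static at::Tensor encoding_forward_example(
    const at::Tensor X, const at::Tensor C, const at::Tensor S) {
  // scaled L2 distances, shape (B, N, K)
  at::Tensor SL = ScaledL2_Forward_CUDA(X, C, S);
  // soft-assignment weights over the K codewords
  at::Tensor A = at::softmax(SL, /*dim=*/2);
  // aggregated residual encodings, shape (B, K, D)
  return Aggregate_Forward_CUDA(A, X, C);
}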