// CUDA kernels and host entry points for synchronized batch normalization:
// per-channel forward/backward transforms and the sum / sum-of-squares
// reductions used to compute the batch statistics.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>

#include "common.h"
#include "device_tensor.h"

namespace {

template <typename DType, typename Acctype, typename DeviceTensor3>
struct GradOp {
  __device__ GradOp(Acctype m, const DeviceTensor3 i, const DeviceTensor3 g)
    : mean(m), input(i), gradOutput(g) {}
  __device__ __forceinline__ Float2<DType, Acctype> operator()(int batch, int plane, int n) {
    DType g = gradOutput[batch][plane][n];
    DType c = ScalarConvert<Acctype, DType>::to(input[batch][plane][n] - mean);
    return Float2<DType, Acctype>(g, g * c);
  }
  const Acctype mean;
  const DeviceTensor3 input;
  const DeviceTensor3 gradOutput;
};

template <typename DType, typename Acctype>
struct SumOp {
  __device__ SumOp(DeviceTensor<DType, 3> i) : input(i) {}
  __device__ __forceinline__ Float2<DType, Acctype> operator()(int batch, int plane, int n) {
    DType g = input[batch][plane][n];
    return Float2<DType, Acctype>(g, g * g);
  }
  DType mean;
  DeviceTensor<DType, 3> input;
};

// Sum across (batch, x/y/z) applying Op() pointwise
template <typename T, typename Op, typename DeviceTensor3>
__device__ T reduce(Op op, DeviceTensor3 tensor, int plane) {
  T sum = (T)0;
  for (int batch = 0; batch < tensor.getSize(0); ++batch) {
    for (int x = threadIdx.x; x < tensor.getSize(2); x += blockDim.x) {
      sum += op(batch, plane, x);
    }
  }

  // sum over NumThreads within a warp
  sum = warpSum(sum);

  // 'transpose', and reduce within warp again
  __shared__ T shared[32];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    shared[threadIdx.x / WARP_SIZE] = sum;
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T)0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();

  // Everyone picks it up, should be broadcast into the whole gradInput
  return shared[0];
}

template <typename DType>
__global__ void BatchNorm_Forward_kernel(
    DeviceTensor<DType, 3> output,
    DeviceTensor<DType, 3> input,
    DeviceTensor<DType, 1> mean,
    DeviceTensor<DType, 1> std,
    DeviceTensor<DType, 1> gamma,
    DeviceTensor<DType, 1> beta) {
  int c = blockIdx.x;
  /* main operation */
  for (int b = 0; b < input.getSize(0); ++b) {
    for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) {
      DType inp = input[b][c][x];
      output[b][c][x] = gamma[c] * (inp - mean[c]) / std[c] + beta[c];
    }
  }
}

template <typename DType>
__global__ void BatchNorm_Backward_kernel(
    DeviceTensor<DType, 3> gradoutput,
    DeviceTensor<DType, 3> input,
    DeviceTensor<DType, 3> gradinput,
    DeviceTensor<DType, 1> gradgamma,
    DeviceTensor<DType, 1> gradbeta,
    DeviceTensor<DType, 1> mean,
    DeviceTensor<DType, 1> std,
    DeviceTensor<DType, 1> gamma,
    DeviceTensor<DType, 1> beta,
    DeviceTensor<DType, 1> gradMean,
    DeviceTensor<DType, 1> gradStd,
    bool train) {
  /* Get the channel index; one block handles one channel */
  int c = blockIdx.x;
  /* main operation */
  GradOp<DType, DType, DeviceTensor<DType, 3>> g(mean[c], input, gradoutput);
  Float2<DType, DType> res = reduce<Float2<DType, DType>,
      GradOp<DType, DType, DeviceTensor<DType, 3>>,
      DeviceTensor<DType, 3>>(g, gradoutput, c);
  DType gradOutputSum = res.v1;
  DType dotP = res.v2;
  DType invstd = DType(1.0) / std[c];
  DType gradScale = invstd * gamma[c];
  if (train && threadIdx.x == 0) {
    gradMean[c] = - gradOutputSum * gamma[c] * invstd;
    gradStd[c]  = - dotP * gamma[c] * invstd * invstd;
  }
  if (gradinput.numElements() > 0) {
    for (int batch = 0; batch < gradoutput.getSize(0); ++batch) {
      for (int x = threadIdx.x; x < gradoutput.getSize(2); x += blockDim.x) {
        gradinput[batch][c][x] = gradoutput[batch][c][x] * gradScale;
      }
    }
  }
  if (gradgamma.numElements() > 0) {
    if (threadIdx.x == 0) {
      gradgamma[c] += dotP * invstd;
    }
  }
  if (gradbeta.numElements() > 0) {
    if (threadIdx.x == 0) {
      gradbeta[c] += gradOutputSum;
    }
  }
}
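// Math implemented by the two kernels above, with one CUDA block per channel c
// and reductions taken over the batch and spatial dimensions:
//
//   forward:   y = gamma[c] * (x - mean[c]) / std[c] + beta[c]
//
//   backward:  gradOutputSum = sum of dy
//              dotP          = sum of dy * (x - mean[c])
//              dL/dx         = dy * gamma[c] / std[c]
//              dL/dgamma[c] += dotP / std[c]
//              dL/dbeta[c]  += gradOutputSum
//              gradMean[c]   = -gradOutputSum * gamma[c] / std[c]      (train only)
//              gradStd[c]    = -dotP * gamma[c] / (std[c] * std[c])    (train only)
//
// Note that mean and std enter the backward pass as independent inputs; the
// gradMean/gradStd outputs are presumably combined with the statistics'
// gradients by the Python caller to complete the synchronized batch-norm
// backward pass across devices (only the per-device part lives in this file).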
template <typename DType>
__global__ void Sum_Square_Forward_kernel(
    DeviceTensor<DType, 3> input,
    DeviceTensor<DType, 1> sum,
    DeviceTensor<DType, 1> square) {
  int c = blockIdx.x;
  /* main operation */
  SumOp<DType, DType> g(input);
  Float2<DType, DType> res = reduce<Float2<DType, DType>,
      SumOp<DType, DType>, DeviceTensor<DType, 3>>(g, input, c);
  DType xsum = res.v1;
  DType xsquare = res.v2;
  if (threadIdx.x == 0) {
    sum[c] = xsum;
    square[c] = xsquare;
  }
}

template <typename DType>
__global__ void Sum_Square_Backward_kernel(
    DeviceTensor<DType, 3> gradInput,
    DeviceTensor<DType, 3> input,
    DeviceTensor<DType, 1> gradSum,
    DeviceTensor<DType, 1> gradSquare) {
  int c = blockIdx.x;
  /* main operation */
  for (int batch = 0; batch < gradInput.getSize(0); ++batch) {
    for (int x = threadIdx.x; x < gradInput.getSize(2); x += blockDim.x) {
      gradInput[batch][c][x] = gradSum[c] + 2 * gradSquare[c] * input[batch][c][x];
    }
  }
}

} // namespace

at::Tensor BatchNorm_Forward_CUDA(
    const at::Tensor input_,
    const at::Tensor mean_,
    const at::Tensor std_,
    const at::Tensor gamma_,
    const at::Tensor beta_) {
  auto output_ = at::zeros_like(input_);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /* one block per channel, threads strided over the spatial dimension */
  dim3 blocks(input_.size(1));
  dim3 threads(getNumThreads(input_.size(2)));
  AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Forward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> output = devicetensor<scalar_t, 3>(output_);
    DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);
    DeviceTensor<scalar_t, 1> mean = devicetensor<scalar_t, 1>(mean_);
    DeviceTensor<scalar_t, 1> std = devicetensor<scalar_t, 1>(std_);
    DeviceTensor<scalar_t, 1> gamma = devicetensor<scalar_t, 1>(gamma_);
    DeviceTensor<scalar_t, 1> beta = devicetensor<scalar_t, 1>(beta_);
    /* kernel function */
    BatchNorm_Forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        output, input, mean, std, gamma, beta);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return output_;
}

std::vector<at::Tensor> BatchNorm_Backward_CUDA(
    const at::Tensor gradoutput_,
    const at::Tensor input_,
    const at::Tensor mean_,
    const at::Tensor std_,
    const at::Tensor gamma_,
    const at::Tensor beta_,
    bool train) {
  /* outputs */
  at::Tensor gradinput_ = at::zeros_like(input_);
  at::Tensor gradgamma_ = at::zeros_like(gamma_);
  at::Tensor gradbeta_ = at::zeros_like(beta_);
  at::Tensor gradMean_ = at::zeros_like(mean_);
  at::Tensor gradStd_ = at::zeros_like(std_);
  /* cuda utils */
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(input_.size(1));
  dim3 threads(getNumThreads(input_.size(2)));
  AT_DISPATCH_FLOATING_TYPES(input_.type(), "BatchNorm_Backward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> gradoutput = devicetensor<scalar_t, 3>(gradoutput_);
    DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);
    DeviceTensor<scalar_t, 3> gradinput = devicetensor<scalar_t, 3>(gradinput_);
    DeviceTensor<scalar_t, 1> gradgamma = devicetensor<scalar_t, 1>(gradgamma_);
    DeviceTensor<scalar_t, 1> gradbeta = devicetensor<scalar_t, 1>(gradbeta_);
    DeviceTensor<scalar_t, 1> mean = devicetensor<scalar_t, 1>(mean_);
    DeviceTensor<scalar_t, 1> std = devicetensor<scalar_t, 1>(std_);
    DeviceTensor<scalar_t, 1> gamma = devicetensor<scalar_t, 1>(gamma_);
    DeviceTensor<scalar_t, 1> beta = devicetensor<scalar_t, 1>(beta_);
    DeviceTensor<scalar_t, 1> gradMean = devicetensor<scalar_t, 1>(gradMean_);
    DeviceTensor<scalar_t, 1> gradStd = devicetensor<scalar_t, 1>(gradStd_);
    /* kernel function */
    BatchNorm_Backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        gradoutput, input, gradinput, gradgamma, gradbeta, mean, std,
        gamma, beta, gradMean, gradStd, train);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return {gradinput_, gradMean_, gradStd_, gradgamma_, gradbeta_};
}

std::vector<at::Tensor> Sum_Square_Forward_CUDA(
    const at::Tensor input_) {
  /* outputs */
  at::Tensor sum_ = input_.type().tensor({input_.size(1)}).zero_();
  at::Tensor square_ = input_.type().tensor({input_.size(1)}).zero_();
  /* cuda utils */
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(input_.size(1));
  dim3 threads(getNumThreads(input_.size(2)));
  AT_DISPATCH_FLOATING_TYPES(input_.type(), "SumSquare_forward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);
    DeviceTensor<scalar_t, 1> sum = devicetensor<scalar_t, 1>(sum_);
    DeviceTensor<scalar_t, 1> square = devicetensor<scalar_t, 1>(square_);
    /* kernel function */
    Sum_Square_Forward_kernel<scalar_t>
        <<<blocks, threads, 0, stream>>>(input, sum, square);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return {sum_, square_};
}
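// Reference-only sketch (not built into the extension): the per-channel
// statistics that Sum_Square_Forward_CUDA produces, expressed with plain ATen
// ops. This is an assumed helper for unit-testing the kernel against eager
// PyTorch, not part of the original source. Input is expected as (N, C, S).
static inline std::vector<at::Tensor> Sum_Square_Reference(const at::Tensor input) {
  at::Tensor sum = input.sum({0, 2});               // per-channel sum of x
  at::Tensor square = (input * input).sum({0, 2});  // per-channel sum of x^2
  return {sum, square};
}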
at::Tensor Sum_Square_Backward_CUDA(
    const at::Tensor input_,
    const at::Tensor gradSum_,
    const at::Tensor gradSquare_) {
  /* outputs */
  at::Tensor gradInput_ = at::zeros_like(input_);
  /* cuda utils */
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dim3 blocks(input_.size(1));
  dim3 threads(getNumThreads(input_.size(2)));
  AT_DISPATCH_FLOATING_TYPES(input_.type(), "SumSquare_Backward_CUDA", ([&] {
    /* Device tensors */
    DeviceTensor<scalar_t, 3> gradInput = devicetensor<scalar_t, 3>(gradInput_);
    DeviceTensor<scalar_t, 3> input = devicetensor<scalar_t, 3>(input_);
    DeviceTensor<scalar_t, 1> gradSum = devicetensor<scalar_t, 1>(gradSum_);
    DeviceTensor<scalar_t, 1> gradSquare = devicetensor<scalar_t, 1>(gradSquare_);
    /* kernel function */
    Sum_Square_Backward_kernel<scalar_t>
        <<<blocks, threads, 0, stream>>>(gradInput, input, gradSum, gradSquare);
  }));
  AT_ASSERT(cudaGetLastError() == cudaSuccess);
  return gradInput_;
}
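// Illustrative only: one way these host entry points are typically exposed to
// Python through the PyTorch C++ extension API. The module and function names
// below are assumptions; the project presumably has its own binding file, so
// this block is left disabled to avoid clashing with it.
#if 0
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("batchnorm_forward", &BatchNorm_Forward_CUDA, "SyncBN forward (CUDA)");
  m.def("batchnorm_backward", &BatchNorm_Backward_CUDA, "SyncBN backward (CUDA)");
  m.def("sumsquare_forward", &Sum_Square_Forward_CUDA, "Per-channel sum and sum-of-squares (CUDA)");
  m.def("sumsquare_backward", &Sum_Square_Backward_CUDA, "Backward of sum / sum-of-squares (CUDA)");
}
#endif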