Unverified Commit 0273d7ad authored by mcarilli, committed by GitHub

[syncBN] (#77)

Adjusted the kernel launch configuration for better performance.
Removed thread divergence from the Welford warp reduction.
parents 5dad4c21 ee67e56a
@@ -14,18 +14,19 @@
 __device__ __forceinline__ int lastpow2(int n)
 {
   int out = 1 << (31 - __clz(n));
   if(n == out)
     out >>= 1;
   return out;
 }
 
 __host__ __forceinline__ int h_next_pow2(unsigned int n) {
+  unsigned int old = n;
   n |= (n >> 1);
   n |= (n >> 2);
   n |= (n >> 4);
   n |= (n >> 8);
   n |= (n >> 16);
-  return n + 1;
+  return n == old? n : n + 1;
 }
 
 __host__ __forceinline__ int h_last_pow2(unsigned int n) {
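Aside (not part of the diff): the updated h_next_pow2 keeps the usual bit-smearing idiom but remembers its input, so values that already have the form 2^k - 1 are returned unchanged instead of being bumped up by one. A minimal host-side sketch of the same body; next_pow2, main, and the test values are illustrative, not part of the file.

#include <cassert>
#include <cstdio>

// Verbatim copy of the new h_next_pow2 body, renamed so it compiles standalone.
static unsigned int next_pow2(unsigned int n) {
  unsigned int old = n;           // remember the input
  n |= (n >> 1);                  // smear the highest set bit downward...
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);                 // ...so n becomes 2^k - 1 for some k
  return n == old ? n : n + 1;    // inputs already of the form 2^k - 1 pass through
}

int main() {
  assert(next_pow2(5)  == 8);     // smear(5)  = 7  -> 8
  assert(next_pow2(7)  == 7);     // already 2^3 - 1, returned unchanged
  assert(next_pow2(8)  == 16);    // smear(8)  = 15 -> 16
  assert(next_pow2(33) == 64);    // smear(33) = 63 -> 64
  std::printf("next_pow2 sketch ok\n");
  return 0;
}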
@@ -71,7 +72,7 @@ __device__ __forceinline__ T reduce_block(T *x, T val)
 }
 
 #define TILE_W 32
-#define MAX_BLOCK_SIZE 256
+#define MAX_BLOCK_SIZE 1024
 
 template<typename T>
 __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
@@ -81,12 +82,11 @@ __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
     auto num_new = __shfl_down_sync(0xffffffff, num, i);
     auto mean_new = __shfl_down_sync(0xffffffff, mean, i);
     auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);
-    if (num_new != 0) {
-      auto dif_mean = mean - mean_new;
-      mean = (mean_new * num_new + mean * num) / (num + num_new);
-      m2n += m2n_new + dif_mean*dif_mean*num*num_new/(num_new+num);
-      num += num_new;
-    }
+    T factor = 1.0 / max(1, (num+num_new));
+    auto dif_mean = mean - mean_new;
+    mean = (mean_new * num_new + mean * num)*factor;
+    m2n += m2n_new + dif_mean*dif_mean*num*num_new*factor;
+    num += num_new;
   }
 }
 
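A note on the warp reduction above: dropping the `if (num_new != 0)` branch keeps all 32 lanes on the same instruction stream, and it is safe because a lane holding no elements shuffles in a zero mean and zero m2n while `max(1, num+num_new)` keeps the divisor nonzero, so merging such a lane changes nothing. The host-side check below compares the old branchy merge with the new branch-free one (a sketch; Partial, merge_branchy, and merge_branchless are names invented here, not in the file).

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>

struct Partial { double mean = 0.0, m2n = 0.0; int num = 0; };

// Old form: skip the merge when the incoming partial is empty.
Partial merge_branchy(Partial a, Partial b) {
  if (b.num != 0) {
    double dif = a.mean - b.mean;
    a.mean = (b.mean * b.num + a.mean * a.num) / (a.num + b.num);
    a.m2n += b.m2n + dif * dif * a.num * b.num / (a.num + b.num);
    a.num += b.num;
  }
  return a;
}

// New form: always merge, guarding the divisor with max(1, ...). An empty
// partial (num == 0, mean == 0, m2n == 0) contributes nothing.
Partial merge_branchless(Partial a, Partial b) {
  double factor = 1.0 / std::max(1, a.num + b.num);
  double dif = a.mean - b.mean;
  a.mean = (b.mean * b.num + a.mean * a.num) * factor;
  a.m2n += b.m2n + dif * dif * a.num * b.num * factor;
  a.num += b.num;
  return a;
}

int main() {
  Partial a{2.0, 8.0, 4}, b{5.0, 2.0, 3}, empty{};
  Partial x = merge_branchy(a, b), y = merge_branchless(a, b);
  assert(std::fabs(x.mean - y.mean) < 1e-12 && std::fabs(x.m2n - y.m2n) < 1e-12);
  Partial z = merge_branchless(a, empty);   // merging an empty lane is a no-op
  assert(z.mean == a.mean && z.m2n == a.m2n && z.num == a.num);
  std::printf("welford merge sketch ok\n");
  return 0;
}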
@@ -159,11 +159,7 @@ __global__ void welford_kernel(
     const int bs,
     const int fs,
     const int ss) {
-  static __shared__ int s_mem[160];
-
   int block_size = blockDim.x * blockDim.y;
-  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
-
   int count = 0;
   accscalar_t x_mean = accscalar_t(0);
   accscalar_t m_2_n = accscalar_t(0);
@@ -176,12 +172,15 @@ __global__ void welford_kernel(
     for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
       count++;
       auto x_n = static_cast<accscalar_t>(input[offset+input_base]);
-      auto x_mean_new = x_mean + (x_n - x_mean) / count;
-      m_2_n = m_2_n + (x_n - x_mean_new) * (x_n - x_mean);
-      x_mean = x_mean_new;
+      auto d = x_n - x_mean;
+      x_mean += d / count;
+      m_2_n += d * (x_n - x_mean);
     }
   }
 
+  static __shared__ int s_mem[160];
+  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
+
   welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
 
   if (thread_id == 0) {
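The rewritten inner loop above is the textbook single-pass Welford update, algebraically the same as the old two-temporary form but with one fewer register. A small host-side reference that checks it against a two-pass mean/variance computation (sketch only; all names are illustrative).

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> xs = {1.0, 4.0, 2.5, -3.0, 7.5, 0.25};

  // Single-pass form matching the new kernel body.
  double x_mean = 0.0, m_2_n = 0.0;
  int count = 0;
  for (double x_n : xs) {
    ++count;
    double d = x_n - x_mean;
    x_mean += d / count;            // running mean
    m_2_n  += d * (x_n - x_mean);   // running sum of squared deviations
  }

  // Two-pass reference: mean first, then sum of squared deviations.
  double mean = 0.0;
  for (double x : xs) mean += x;
  mean /= xs.size();
  double m2 = 0.0;
  for (double x : xs) m2 += (x - mean) * (x - mean);

  assert(std::fabs(x_mean - mean) < 1e-12);
  assert(std::fabs(m_2_n - m2) < 1e-12);
  std::printf("mean=%g  biased var=%g\n", x_mean, m_2_n / xs.size());
  return 0;
}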
@@ -201,16 +200,18 @@ __global__ void batchnorm_forward_kernel(
     const layerscalar_t* __restrict__ shift,
     scalar_t* __restrict__ out,
     const int ss,
+    const int bs,
     const float eps) {
-  int address_base = blockIdx.x*ss + blockIdx.y*gridDim.x*ss;
-
   auto m_c = mean[blockIdx.x];
   auto inv_std_c = static_cast<accscalar_t>(rsqrt(var[blockIdx.x] + eps));
   auto w_c = static_cast<accscalar_t>(weight[blockIdx.x]);
   auto s_c = static_cast<accscalar_t>(shift[blockIdx.x]);
 
-  for (int offset = threadIdx.x; offset < ss ; offset+= blockDim.x) {
-    out[address_base+offset] = static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base+offset]) - m_c ) * inv_std_c + s_c);
+  for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
+    int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
+    for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
+      out[address_base+offset] = static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base+offset]) - m_c ) * inv_std_c + s_c);
+    }
   }
 }
 
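On the new launch geometry: instead of one block per (feature, sample) pair, the forward kernel now runs a 3D grid in which blockIdx.y/threadIdx.y stride over the batch and blockIdx.z/threadIdx.x stride over the spatial extent. The CPU emulation below replays those two grid-stride loops with toy sizes and checks that every element of each feature plane is visited exactly once (a sketch; none of these host-side loop variables exist in the file).

#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int bs = 5, fs = 3, ss = 7;               // toy batch, feature, spatial sizes
  const int gridDim_x = fs, gridDim_y = 2, gridDim_z = 2;
  const int blockDim_x = 4, blockDim_y = 2;

  std::vector<int> visits(bs * fs * ss, 0);
  for (int bx = 0; bx < gridDim_x; ++bx)
    for (int by = 0; by < gridDim_y; ++by)
      for (int bz = 0; bz < gridDim_z; ++bz)
        for (int ty = 0; ty < blockDim_y; ++ty)
          for (int tx = 0; tx < blockDim_x; ++tx)
            // body of the kernel's two grid-stride loops
            for (int batch_offset = by*blockDim_y + ty; batch_offset < bs;
                 batch_offset += gridDim_y*blockDim_y) {
              int address_base = bx*ss + batch_offset*gridDim_x*ss;   // as in the kernel
              for (int offset = tx + bz*blockDim_x; offset < ss;
                   offset += gridDim_z*blockDim_x)
                ++visits[address_base + offset];
            }

  for (int v : visits) assert(v == 1);            // every element touched exactly once
  std::printf("forward-kernel index sketch ok\n");
  return 0;
}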
@@ -267,7 +268,7 @@ __global__ void reduce_bn_kernel(
   sum_dy = reduce_block((accscalar_t*)s_mem, sum_dy);
   __syncthreads();
   sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);
 
   if (thread_id == 0) {
     grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
     grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
@@ -288,17 +289,19 @@ __global__ void batchnorm_backward_kernel(
     const accscalar_t* __restrict__ mean_dy_xmu,
     scalar_t* __restrict__ grad_input,
     const int ss,
+    const int bs,
     const float eps) {
-  int address_base = blockIdx.x*ss + blockIdx.y*gridDim.x*ss;
-
   auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
   auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
   auto factor_1_c = static_cast<accscalar_t>(var[blockIdx.x]) + eps;
   auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) / sqrt(factor_1_c);
   factor_1_c /= static_cast<accscalar_t>(mean_dy_xmu[blockIdx.x]);
 
-  for (int offset = threadIdx.x; offset < ss ; offset+= blockDim.x) {
-    grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) / factor_1_c) * factor_2_c;
+  for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
+    int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
+    for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
+      grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) / factor_1_c) * factor_2_c;
+    }
   }
 }
 
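On the backward kernel's refactored constants: it folds the per-feature terms into factor_1_c = (var + eps) / mean_dy_xmu and factor_2_c = weight / sqrt(var + eps). The scalar check below confirms that the expression written in the kernel equals the expanded form (dy - mean_dy - (x - mean) * mean_dy_xmu / (var + eps)) * weight / sqrt(var + eps). A sketch with made-up sample values, not part of the file.

#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  const double x = 1.7, dy = -0.3, mean = 0.4, var = 2.0, eps = 1e-5;
  const double weight = 1.25, mean_dy = 0.1, mean_dy_xmu = 0.05;

  // As written in the kernel.
  double factor_1_c = (var + eps) / mean_dy_xmu;
  double factor_2_c = weight / std::sqrt(var + eps);
  double dx_kernel = (dy - mean_dy - (x - mean) / factor_1_c) * factor_2_c;

  // Expanded form of the same expression.
  double dx_ref = (dy - mean_dy - (x - mean) * mean_dy_xmu / (var + eps))
                  * weight / std::sqrt(var + eps);

  assert(std::fabs(dx_kernel - dx_ref) < 1e-12);
  std::printf("dx = %.10f\n", dx_kernel);
  return 0;
}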
@@ -322,7 +325,7 @@ __global__ void welford_kernel_parallel(
   int input_base = blockIdx.x*ns + threadIdx.x;
   int thread_id = threadIdx.x;
 
   // load data;
   auto x_mean = static_cast<accscalar_t>(mean[input_base]);
   auto m_2_n = static_cast<accscalar_t>(var_biased[input_base]) * numel;
   auto count = numel;
@@ -337,7 +340,7 @@ __global__ void welford_kernel_parallel(
     out_var_biased[blockIdx.x] = static_cast<scalar_t>(m_2_n/count);
   }
 }
 
 std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
   const auto batch_size = input.size(0);
@@ -350,8 +353,8 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
   at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
   at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));
 
-  int block_x = TILE_W;
-  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / block_x));
+  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32));
+  int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
   const dim3 block(block_x, block_y);
   const dim3 grid(feature_size);
 
@@ -386,9 +389,12 @@ at::Tensor batchnorm_forward_CUDA(
   auto space_size = get_tensor_spatial_size(input);
 
-  int block = min(MAX_BLOCK_SIZE, h_next_pow2(space_size)/4);
-  // TODO(jie): should I do 1 block per feature?
-  const dim3 grid(feature_size, batch_size);
+  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+  int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
+  const dim3 block(block_x, block_y);
+  int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
+  int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
+  const dim3 grid(feature_size, batch_group_size, grid_z);
 
   auto stream = at::cuda::getCurrentCUDAStream();
 
   if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
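On the new launch configuration: grid.y tiles the batch and grid.z tiles the spatial extent, and both are clamped to 65535 to respect CUDA's grid-dimension limits on those axes. The host-side walk-through below evaluates the new formulas for one example NCHW shape. It is only a sketch: it assumes h_last_pow2 rounds its argument down to the nearest power of two (that function's body is outside this diff), and every name in it is local to the sketch.

#include <algorithm>
#include <cstdio>

constexpr int MAX_BLOCK_SIZE = 1024;   // value after this diff

int last_pow2(int n) {                 // assumed behaviour of h_last_pow2
  int p = 1;
  while (p * 2 <= n) p *= 2;
  return p;
}

int main() {
  // Example input: NCHW = (64, 128, 32, 32).
  const int batch_size = 64, feature_size = 128, space_size = 32 * 32;

  int block_x = std::max(32, std::min(MAX_BLOCK_SIZE, last_pow2(space_size) / 4));
  int block_y = std::max(1,  std::min(MAX_BLOCK_SIZE / block_x, last_pow2(batch_size) / 4));
  int grid_z  = std::max(1,  std::min(65535, last_pow2(space_size) / 4 / block_x));
  int batch_group_size = std::max(1, std::min(65535, last_pow2(batch_size) / block_y));

  // For this shape: block = (256, 4), grid = (128, 16, 1).
  std::printf("block = (%d, %d), grid = (%d, %d, %d)\n",
              block_x, block_y, feature_size, batch_group_size, grid_z);
  return 0;
}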
@@ -402,10 +408,11 @@ at::Tensor batchnorm_forward_CUDA(
           shift.data<accscalar_t>(),
           out.data<scalar_t>(),
           space_size,
+          batch_size,
           eps);
     }));
   } else {
     AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -416,6 +423,7 @@ at::Tensor batchnorm_forward_CUDA(
           shift.data<scalar_t>(),
           out.data<scalar_t>(),
           space_size,
+          batch_size,
           eps);
     }));
   }
@@ -442,11 +450,10 @@ std::vector<at::Tensor> reduce_bn_CUDA(
   auto space_size = get_tensor_spatial_size(input);
 
-  int block_x = TILE_W;
-  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / block_x));
+  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ 32));
+  int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size)));
   const dim3 block(block_x, block_y);
   const dim3 grid(feature_size);
-  // shared memory used for reduce on sum_dy, sum_dy_xmu;
 
   auto stream = at::cuda::getCurrentCUDAStream();
 
   if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
@@ -467,7 +474,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
           eps);
     }));
   } else {
     AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -485,7 +492,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
           eps);
     }));
   }
 
   return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
 }
 
@@ -505,9 +512,13 @@ at::Tensor batchnorm_backward_CUDA(
   auto space_size = get_tensor_spatial_size(input);
 
-  int block = min(MAX_BLOCK_SIZE, h_next_pow2(space_size)/4);
-  // TODO(jie): should I do 1 block per feature?
-  const dim3 grid(feature_size, batch_size);
+  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+  int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
+  const dim3 block(block_x, block_y);
+
+  int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
+  int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
+  const dim3 grid(feature_size, batch_group_size, grid_z);
 
   auto stream = at::cuda::getCurrentCUDAStream();
 
   if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
@@ -523,10 +534,11 @@ at::Tensor batchnorm_backward_CUDA(
           mean_dy_xmu.data<accscalar_t>(),
           grad_input.data<scalar_t>(),
           space_size,
+          batch_size,
           eps);
     }));
   } else {
     AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
       using accscalar_t = at::acc_type<scalar_t, true>;
       batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -539,10 +551,11 @@ at::Tensor batchnorm_backward_CUDA(
           mean_dy_xmu.data<accscalar_t>(),
           grad_input.data<scalar_t>(),
           space_size,
+          batch_size,
           eps);
     }));
   }
 
   return grad_input;
 }