Unverified Commit 0273d7ad authored by mcarilli, committed by GitHub

[syncBN] (#77)

Adjusted kernel configuration for better performance.
Removed divergence in the Welford warp reduction.
parents 5dad4c21 ee67e56a
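
As background for the second change, here is a minimal host-side C++ sketch (the names WelfordState and welford_merge are hypothetical, not part of this commit) of the branch-free Chan/Welford merge that the updated warp_reduce_mean_m2n applies at each shuffle step. Clamping the denominator with max(1, ...) makes merging an empty partner a no-op, so the previous if (num_new != 0) guard, and the intra-warp divergence it caused, can be dropped.

#include <algorithm>

// Running statistics for one partition: element count, mean, and the sum
// of squared deviations from the mean (m2n), as used by the Welford kernels.
struct WelfordState {
    int num = 0;
    float mean = 0.f;
    float m2n = 0.f;
};

// Merge partition b into a without branching. When b.num == 0, every
// correction term below carries a factor of b.num == 0 and the clamped
// denominator leaves the mean unchanged, so the update is a no-op.
inline void welford_merge(WelfordState &a, const WelfordState &b) {
    float factor = 1.0f / std::max(1, a.num + b.num);
    float dif_mean = a.mean - b.mean;
    a.mean = (b.mean * b.num + a.mean * a.num) * factor;
    a.m2n += b.m2n + dif_mean * dif_mean * a.num * b.num * factor;
    a.num += b.num;
}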
@@ -14,18 +14,19 @@
__device__ __forceinline__ int lastpow2(int n)
{
int out = 1 << (31 - __clz(n));
if(n == out)
out >>= 1;
return out;
}
__host__ __forceinline__ int h_next_pow2(unsigned int n) {
unsigned int old = n;
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
return n + 1;
return n == old? n : n + 1;
}
__host__ __forceinline__ int h_last_pow2(unsigned int n) {
@@ -71,7 +72,7 @@ __device__ __forceinline__ T reduce_block(T *x, T val)
}
#define TILE_W 32
#define MAX_BLOCK_SIZE 256
#define MAX_BLOCK_SIZE 1024
template<typename T>
__device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
@@ -81,12 +82,11 @@ __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
auto num_new = __shfl_down_sync(0xffffffff, num, i);
auto mean_new = __shfl_down_sync(0xffffffff, mean, i);
auto m2n_new = __shfl_down_sync(0xffffffff, m2n, i);
if (num_new != 0) {
auto dif_mean = mean - mean_new;
mean = (mean_new * num_new + mean * num) / (num + num_new);
m2n += m2n_new + dif_mean*dif_mean*num*num_new/(num_new+num);
num += num_new;
}
T factor = 1.0 / max(1, (num+num_new));
auto dif_mean = mean - mean_new;
mean = (mean_new * num_new + mean * num)*factor;
m2n += m2n_new + dif_mean*dif_mean*num*num_new*factor;
num += num_new;
}
}
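
A quick numeric check of the hunk above, with made-up values: merging num = 5, mean = 2, m2n = 10 with an empty partner (num_new = 0, m2n_new = 0) gives factor = 1 / max(1, 5 + 0) = 1/5, mean = (mean_new * 0 + 2 * 5) / 5 = 2, and m2n += 0 + dif_mean^2 * 5 * 0 / 5 = 0, so the running statistics are unchanged regardless of the partner's mean_new. This is exactly the case the removed if (num_new != 0) branch used to skip, but now every lane in the warp executes the same instruction sequence.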
@@ -159,11 +159,7 @@ __global__ void welford_kernel(
const int bs,
const int fs,
const int ss) {
static __shared__ int s_mem[160];
int block_size = blockDim.x * blockDim.y;
accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
int count = 0;
accscalar_t x_mean = accscalar_t(0);
accscalar_t m_2_n = accscalar_t(0);
@@ -176,12 +172,15 @@ __global__ void welford_kernel(
for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) {
count++;
auto x_n = static_cast<accscalar_t>(input[offset+input_base]);
auto x_mean_new = x_mean + (x_n - x_mean) / count;
m_2_n = m_2_n + (x_n - x_mean_new) * (x_n - x_mean);
x_mean = x_mean_new;
auto d = x_n - x_mean;
x_mean += d / count;
m_2_n += d * (x_n - x_mean);
}
}
static __shared__ int s_mem[160];
accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
if (thread_id == 0) {
@@ -201,16 +200,18 @@ __global__ void batchnorm_forward_kernel(
const layerscalar_t* __restrict__ shift,
scalar_t* __restrict__ out,
const int ss,
const int bs,
const float eps) {
int address_base = blockIdx.x*ss + blockIdx.y*gridDim.x*ss;
auto m_c = mean[blockIdx.x];
auto inv_std_c = static_cast<accscalar_t>(rsqrt(var[blockIdx.x] + eps));
auto w_c = static_cast<accscalar_t>(weight[blockIdx.x]);
auto s_c = static_cast<accscalar_t>(shift[blockIdx.x]);
for (int offset = threadIdx.x; offset < ss ; offset+= blockDim.x) {
out[address_base+offset] = static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base+offset]) - m_c ) * inv_std_c + s_c);
for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
out[address_base+offset] = static_cast<scalar_t>(w_c * (static_cast<accscalar_t>(input[address_base+offset]) - m_c ) * inv_std_c + s_c);
}
}
}
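
To make the new indexing concrete, a small sketch of the address arithmetic, assuming a contiguous NCHW input where gridDim.x equals the feature count fs and ss is the per-channel spatial size (the helper below is purely illustrative and not part of the commit):

// Hypothetical helper that mirrors the kernel's address computation:
// address_base + offset == batch * fs * ss + feature * ss + spatial.
inline int nchw_flat_index(int batch, int feature, int spatial, int fs, int ss) {
    return batch * fs * ss + feature * ss + spatial;
}

In the kernel, feature is blockIdx.x, batch runs over blockIdx.y*blockDim.y + threadIdx.y with a stride of gridDim.y*blockDim.y, and spatial runs over blockIdx.z*blockDim.x + threadIdx.x with a stride of gridDim.z*blockDim.x, so the launch is no longer tied to exactly one block per (feature, batch) pair.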
@@ -267,7 +268,7 @@ __global__ void reduce_bn_kernel(
sum_dy = reduce_block((accscalar_t*)s_mem, sum_dy);
__syncthreads();
sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu);
if (thread_id == 0) {
grad_bias[blockIdx.x] = static_cast<layerscalar_t>(sum_dy);
grad_weight[blockIdx.x] = static_cast<layerscalar_t>(sum_dy_xmu * factor);
@@ -288,17 +289,19 @@ __global__ void batchnorm_backward_kernel(
const accscalar_t* __restrict__ mean_dy_xmu,
scalar_t* __restrict__ grad_input,
const int ss,
const int bs,
const float eps) {
int address_base = blockIdx.x*ss + blockIdx.y*gridDim.x*ss;
auto m_c = static_cast<accscalar_t>(mean[blockIdx.x]);
auto m_dy_c = static_cast<accscalar_t>(mean_dy[blockIdx.x]);
auto factor_1_c = static_cast<accscalar_t>(var[blockIdx.x]) + eps;
auto factor_2_c = static_cast<accscalar_t>(weight[blockIdx.x]) / sqrt(factor_1_c);
factor_1_c /= static_cast<accscalar_t>(mean_dy_xmu[blockIdx.x]);
for (int offset = threadIdx.x; offset < ss ; offset+= blockDim.x) {
grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) / factor_1_c) * factor_2_c;
for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) {
int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss;
for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) {
grad_input[address_base+offset] = (static_cast<accscalar_t>(grad_output[address_base+offset]) - m_dy_c - (static_cast<accscalar_t>(input[address_base+offset]) - m_c) / factor_1_c) * factor_2_c;
}
}
}
@@ -322,7 +325,7 @@ __global__ void welford_kernel_parallel(
int input_base = blockIdx.x*ns + threadIdx.x;
int thread_id = threadIdx.x;
// load data;
auto x_mean = static_cast<accscalar_t>(mean[input_base]);
auto m_2_n = static_cast<accscalar_t>(var_biased[input_base]) * numel;
auto count = numel;
@@ -337,7 +340,7 @@ __global__ void welford_kernel_parallel(
out_var_biased[blockIdx.x] = static_cast<scalar_t>(m_2_n/count);
}
}
std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
const auto batch_size = input.size(0);
@@ -350,8 +353,8 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));
int block_x = TILE_W;
int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / block_x));
int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32));
int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
const dim3 block(block_x, block_y);
const dim3 grid(feature_size);
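
For a sense of scale, with hypothetical sizes batch_size = 64 and space_size = 196 (numbers chosen only for illustration), the new configuration works out to:

block_y = min(h_last_pow2(64), 1024 / 32) = min(64, 32) = 32
block_x = max(1, min(1024 / 32, h_last_pow2(196))) = max(1, min(32, 128)) = 32

so the Welford reduction now runs in 32 x 32 = 1024-thread blocks, where the old TILE_W-based configuration with MAX_BLOCK_SIZE = 256 was capped at 32 x 8 = 256 threads per block.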
@@ -386,9 +389,12 @@ at::Tensor batchnorm_forward_CUDA(
auto space_size = get_tensor_spatial_size(input);
int block = min(MAX_BLOCK_SIZE, h_next_pow2(space_size)/4);
// TODO(jie): should I do 1 block per feature?
const dim3 grid(feature_size, batch_size);
int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
const dim3 block(block_x, block_y);
int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
const dim3 grid(feature_size, batch_group_size, grid_z);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
@@ -402,10 +408,11 @@ at::Tensor batchnorm_forward_CUDA(
shift.data<accscalar_t>(),
out.data<scalar_t>(),
space_size,
batch_size,
eps);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_forward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_forward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -416,6 +423,7 @@ at::Tensor batchnorm_forward_CUDA(
shift.data<scalar_t>(),
out.data<scalar_t>(),
space_size,
batch_size,
eps);
}));
}
@@ -442,11 +450,10 @@ std::vector<at::Tensor> reduce_bn_CUDA(
auto space_size = get_tensor_spatial_size(input);
int block_x = TILE_W;
int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / block_x));
int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ 32));
int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size)));
const dim3 block(block_x, block_y);
const dim3 grid(feature_size);
// shared memory used for reduce on sum_dy, sum_dy_xmu;
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
@@ -467,7 +474,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
eps);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward_reduce", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
reduce_bn_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -485,7 +492,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
eps);
}));
}
return {mean_dy, mean_dy_xmu, grad_weight, grad_bias};
}
@@ -505,9 +512,13 @@ at::Tensor batchnorm_backward_CUDA(
auto space_size = get_tensor_spatial_size(input);
int block = min(MAX_BLOCK_SIZE, h_next_pow2(space_size)/4);
// TODO(jie): should I do 1 block per feature?
const dim3 grid(feature_size, batch_size);
int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
const dim3 block(block_x, block_y);
int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y));
const dim3 grid(feature_size, batch_group_size, grid_z);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.type().scalarType() == at::ScalarType::Half && weight.type().scalarType() == at::ScalarType::Float) {
@@ -523,10 +534,11 @@ at::Tensor batchnorm_backward_CUDA(
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
space_size,
batch_size,
eps);
}));
} else {
AT_CHECK(input.type().scalarType() == weight.type().scalarType(), "input.type().scalarType() is not supported with weight.type().scalarType()");
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "batchnorm_backward", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
batchnorm_backward_kernel<scalar_t, accscalar_t, scalar_t><<<grid, block, 0, stream>>>(
@@ -539,10 +551,11 @@ at::Tensor batchnorm_backward_CUDA(
mean_dy_xmu.data<accscalar_t>(),
grad_input.data<scalar_t>(),
space_size,
batch_size,
eps);
}));
}
return grad_input;
}