Unverified commit 3f49dbf0 authored by Jeff Daily, committed by GitHub

fix bugs in syncbn (#46)

- incorrect use of __shfl_down
- fix warp size assumptions
- update unit tests to exit on failure
parent c1e88fae
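
Context on the shuffle fix above: CUDA's __shfl_down_sync takes (mask, var, delta), while ROCm's __shfl_down takes (var, delta) and no mask, so a plain macro alias forwarded the mask where HIP expected the value to be shuffled. Below is a minimal, self-contained sketch of the intended pattern; the _sketch names are hypothetical, and the real file defines SHFL_DOWN, WARP_SIZE, and warp_reduce_sum itself.

#if defined(__HIP_PLATFORM_HCC__)
  // HIP: the intrinsic has no mask argument, so the macro discards it.
  #define SHFL_DOWN_SKETCH(mask, val, i) __shfl_down(val, i)
  #define WARP_SIZE_SKETCH 64
#else
  // CUDA: the intrinsic already takes (mask, var, delta) in call-site order.
  #define SHFL_DOWN_SKETCH(mask, val, i) __shfl_down_sync(mask, val, i)
  #define WARP_SIZE_SKETCH 32
#endif

template <typename T>
__device__ __forceinline__ T warp_reduce_sum_sketch(T val) {
  // Start at half the hardware warp width (32 on AMD, 16 on NVIDIA) rather
  // than a hard-coded constant, so every lane participates on both platforms.
  for (int offset = WARP_SIZE_SKETCH / 2; offset > 0; offset /= 2)
    val += SHFL_DOWN_SKETCH(0xffffffff, val, offset);
  return val;
}
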
@@ -12,7 +12,7 @@
 #include "compat.h"
 #if defined __HIP_PLATFORM_HCC__
-#define SHFL_DOWN __shfl_down
+#define SHFL_DOWN(mask,val,i) __shfl_down(val, i)
 #else
 #define SHFL_DOWN __shfl_down_sync
 #endif
@@ -44,8 +44,11 @@ __host__ __forceinline__ int h_last_pow2(unsigned int n) {
   return n - (n >> 1);
 }
+#ifdef __HIP_PLATFORM_HCC__
+#define WARP_SIZE 64
+#else
 #define WARP_SIZE 32
+#endif
 template<typename T>
 __device__ __forceinline__ T warp_reduce_sum(T val)
@@ -61,25 +64,27 @@ __device__ __forceinline__ T reduce_block(T *x, T val)
 {
   int tid = threadIdx.y*blockDim.x + threadIdx.x;
   int blockSize = blockDim.x * blockDim.y;
+  int lane = tid % WARP_SIZE;
+  int wid = tid / WARP_SIZE;
-  if (blockSize > 32) {
+  if (blockSize > WARP_SIZE) {
     val = warp_reduce_sum(val);
-    if (tid % WARP_SIZE == 0)
-      x[tid/WARP_SIZE] = val;
+    if (lane == 0)
+      x[wid] = val;
     __syncthreads();
-    val = (tid < blockSize / WARP_SIZE? x[tid%WARP_SIZE] : T(0));
+    val = (tid < blockSize / WARP_SIZE? x[lane] : T(0));
   }
-  if(tid/WARP_SIZE==0) val = warp_reduce_sum(val);
+  if(wid==0) val = warp_reduce_sum(val);
   return val;
 }
 #define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency
 #define ELEMENTS_PER_THREAD 16
-#define OPTIMAL_TILE_W 32
+#define OPTIMAL_TILE_W WARP_SIZE
 #define MAX_H_BLOCK 128
 #define MAX_BLOCK_SIZE 512
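
The shared scratch consumed by reduce_block needs one slot per warp, so a block of at most MAX_BLOCK_SIZE (512) threads never needs more than 512 / WARP_SIZE entries, 16 on NVIDIA or 8 on AMD; sizing the array to WARP_SIZE covers both, which is the same reasoning behind the s_mem changes further down. A hedged usage sketch (block_sum_sketch is a hypothetical kernel, not part of this file):

__global__ void block_sum_sketch(const float* in, float* out) {
  // One scratch slot per warp; WARP_SIZE entries is enough for any block of
  // up to MAX_BLOCK_SIZE threads on either warp width.
  __shared__ float scratch[WARP_SIZE];
  float v = in[blockIdx.x * blockDim.x + threadIdx.x];
  v = reduce_block<float>(scratch, v);
  if (threadIdx.x == 0 && threadIdx.y == 0)
    out[blockIdx.x] = v;
}
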
@@ -137,11 +142,7 @@ __device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num)
     auto num_new = SHFL_DOWN(0xffffffff, num, i);
     auto mean_new = SHFL_DOWN(0xffffffff, mean, i);
     auto m2n_new = SHFL_DOWN(0xffffffff, m2n, i);
-#if defined __HIP_PLATFORM_HCC__
-    welford_merge_element<T, int>(num, mean, m2n, num_new, mean_new, m2n_new);
-#else
     welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new);
-#endif
   }
 }
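
For context, welford_merge_element (defined elsewhere in this file and not shown in the diff) combines two partial (count, mean, m2n) triples; the removal above simply drops a HIP-only call that spelled out the template arguments. A sketch under the assumption that it follows the standard Chan et al. pairwise-merge formulas (welford_merge_sketch is a hypothetical name):

template <typename T, typename C>
__device__ __forceinline__ void welford_merge_sketch(
    C& count, T& mean, T& m2n,
    const C& count_new, const T& mean_new, const T& m2n_new) {
  // Weighted combination of the two running means, plus the cross term that
  // corrects the summed second moments; the max() guards empty partitions.
  T factor = T(1.0) / max(1, count + count_new);
  T delta = mean - mean_new;
  mean = (mean * count + mean_new * count_new) * factor;
  m2n += m2n_new + delta * delta * count * count_new * factor;
  count += count_new;
}
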
@@ -158,7 +159,7 @@ __device__ void welford_reduce_mean_m2n(
   int lane = thread_id % WARP_SIZE;
   int wid = thread_id / WARP_SIZE;
-  if (block_size > 32) {
+  if (block_size > WARP_SIZE) {
     warp_reduce_mean_m2n(mean, m2n, num);
     if (lane == 0) {
       x[wid*2] = mean;
@@ -265,6 +266,9 @@ __device__ __forceinline__ void merge_block_vertical(T& sum_dy,
 // welford kernel calculating mean/biased_variance/unbiased_variance
 template <typename scalar_t, typename accscalar_t, typename outscalar_t>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void welford_kernel(
       const scalar_t* __restrict__ input,
       outscalar_t* __restrict__ out_mean,
@@ -291,8 +295,8 @@ __global__ void welford_kernel(
     }
   }
-  static __shared__ int s_mem[160];
-  accscalar_t* s_mem_ac = (accscalar_t*) &s_mem[32];
+  static __shared__ int s_mem[WARP_SIZE];
+  static __shared__ accscalar_t s_mem_ac[WARP_SIZE*2];
   welford_reduce_mean_m2n<accscalar_t>(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id);
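
The __launch_bounds__(MAX_BLOCK_SIZE) lines added in front of this kernel, and repeated for every kernel below, tell the ROCm compiler the largest block the host code ever launches (512 threads here), presumably so it can budget registers for that block size; the attribute is guarded because it is only needed on HIP builds. A minimal illustration of the pattern (example_kernel is hypothetical):

#ifdef __HIP_PLATFORM_HCC__
__launch_bounds__(MAX_BLOCK_SIZE)
#endif
__global__ void example_kernel(float* out) {
  // Launched with at most MAX_BLOCK_SIZE threads per block by the host code.
  out[blockIdx.x * blockDim.x + threadIdx.x] = 0.f;
}
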
@@ -304,6 +308,9 @@ __global__ void welford_kernel(
 // elementwise BN kernel
 template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void batchnorm_forward_kernel(
       const scalar_t* __restrict__ input,
       const accscalar_t* __restrict__ mean,
@@ -331,6 +338,9 @@ __global__ void batchnorm_forward_kernel(
 // Breaking the grad_input to two step to support sync BN, which requires all
 // reduce of the intermediate results across processes.
 template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void reduce_bn_kernel(
       const scalar_t* __restrict__ input,
       const scalar_t* __restrict__ grad_output,
@@ -343,7 +353,7 @@ __global__ void reduce_bn_kernel(
       const int bs,
       const int fs,
       const int ss) {
-  static __shared__ int s_mem[64];
+  static __shared__ int s_mem[WARP_SIZE];
   //int total_item_num = bs * ss;
   int thread_id = threadIdx.y*blockDim.x + threadIdx.x;
@@ -395,6 +405,9 @@ __global__ void reduce_bn_kernel(
 // elementwise backward BN kernel
 template <typename scalar_t, typename accscalar_t, typename layerscalar_t>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void batchnorm_backward_kernel(
       const scalar_t* __restrict__ grad_output,
       const scalar_t* __restrict__ input,
@@ -434,6 +447,9 @@ template
     typename accscalar_t,
     typename outscalar_t,
     int PARALLEL_LOADS>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void
 welford_kernel_c_last(
       const scalar_t* __restrict__ input,
@@ -575,6 +591,9 @@ welford_kernel_c_last(
 // parallel welford kernel to further reduce mean / biased_var
 // into mean / unbiased_var / inv_std across multiple processes.
 template <typename scalar_t>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void welford_kernel_parallel(
       const scalar_t* __restrict__ mean,
       const scalar_t* __restrict__ var_biased,
@@ -608,6 +627,9 @@ template <
     typename accscalar_t,
     typename layerscalar_t,
     int PARALLEL_LOADS>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void batchnorm_forward_c_last_kernel(
       const scalar_t* __restrict__ input,
       const scalar_t* __restrict__ z,
@@ -658,6 +680,9 @@ template <
     typename accscalar_t,
     typename layerscalar_t,
     int PARALLEL_LOADS>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void relu_backward_c_last_kernel(
       const scalar_t* __restrict__ grad_output,
       const scalar_t* __restrict__ input,
@@ -708,6 +733,9 @@ template
     typename accscalar_t,
     typename layerscalar_t,
     int PARALLEL_LOADS>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void reduce_bn_c_last_kernel(
       const scalar_t* __restrict__ input,
       const scalar_t* __restrict__ grad_output,
@@ -861,6 +889,9 @@ template <
     typename accscalar_t,
     typename layerscalar_t,
     int PARALLEL_LOADS>
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(MAX_BLOCK_SIZE)
+#endif
 __global__ void batchnorm_backward_c_last_kernel(
       const scalar_t* __restrict__ grad_output,
       const scalar_t* __restrict__ input,
@@ -921,7 +952,7 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
   at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type));
   at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type));
-  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / 32));
+  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / WARP_SIZE));
   int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size)));
   const dim3 block(block_x, block_y);
   const dim3 grid(feature_size);
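
A worked check of the launch-shape arithmetic above, using illustrative sizes (a batch of 32 and a spatial extent of 4096 are assumptions; h_last_pow2 rounds its argument down to a power of two, so it is the identity on these values). Either way the block stays at MAX_BLOCK_SIZE threads while the row count scales with the hardware warp width, which is why the hard-coded 32 had to become WARP_SIZE.

#include <algorithm>
#include <cassert>
#include <initializer_list>

void launch_shape_example() {
  const int MAX_BLOCK = 512;
  for (int warp : {32 /* CUDA */, 64 /* ROCm */}) {
    int block_y = std::min(32 /* h_last_pow2(batch_size) */, MAX_BLOCK / warp);
    int block_x = std::max(1, std::min(MAX_BLOCK / block_y, 4096 /* h_last_pow2(space_size) */));
    assert(block_x * block_y == 512);  // 32 x 16 on CUDA, 64 x 8 on ROCm
  }
}
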
@@ -957,7 +988,7 @@ at::Tensor batchnorm_forward_CUDA(
   auto space_size = get_tensor_spatial_size(input);
-  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+  int block_x = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
   int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
   const dim3 block(block_x, block_y);
   int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
@@ -1030,7 +1061,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
   auto space_size = get_tensor_spatial_size(input);
-  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ 32));
+  int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ WARP_SIZE));
   int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size)));
   const dim3 block(block_x, block_y);
   const dim3 grid(feature_size);
@@ -1097,7 +1128,7 @@ at::Tensor batchnorm_backward_CUDA(
   auto space_size = get_tensor_spatial_size(input);
-  int block_x = max(32, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
+  int block_x = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4));
   int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4));
   const dim3 block(block_x, block_y);
   int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x));
...
@@ -109,3 +109,4 @@ if sbn_result:
 else:
     print("*SBN single gpu failed*")
+assert sbn_result
@@ -157,3 +157,6 @@ if sbn_result_c_last:
     print("====SBN channel last single gpu passed tests")
 else:
     print("*SBN channel last single gpu failed*")
+assert sbn_result
+assert sbn_result_c_last
@@ -60,7 +60,11 @@ inp = np.random.randn(batch_size, feature_size, space_size, space_size).astype(d
 grad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
 weight = np.random.randn(feature_size).astype(dtype)
 bias = np.random.randn(feature_size).astype(dtype)
+#count = torch.cuda.IntTensor([batch_size*space_size**2])
+count = [ space_size**2 * ( (i+1) * batch_size // args.world_size - i * batch_size // args.world_size ) for i in range(0, args.world_size)]
+count = torch.cuda.IntTensor(count)
+print("--- count : " , count)
 type_tensor = torch.cuda.FloatTensor
 if args.fp16:
@@ -153,7 +157,7 @@ mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).co
 grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
 mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
-grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
+grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu, count)
 if args.local_rank == 0:
     sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
...
@@ -178,3 +178,5 @@ if sbn_result:
     print("====SBN two gpu passed tests")
 else:
     print("*SBN two gpu failed*")
+assert sbn_result
-python python_single_gpu_unit_test.py
-python single_gpu_unit_test.py
-python test_batchnorm1d.py
-python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py
-python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16
-python -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex
+python python_single_gpu_unit_test.py || exit 1
+python single_gpu_unit_test.py || exit 1
+python test_batchnorm1d.py || exit 1
+python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py || exit 1
+python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16 || exit 1
+python -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex || exit 1
 #beware, you need a system with at least 4 gpus to test group_size<world_size
 #python -m torch.distributed.launch --nproc_per_node=4 test_groups.py --group_size=2