Commit cdc17060 authored by luise.chen, committed by flyingdown

Luise/gbn optimization (#105)

* GroupBN: Reduce buffering to better hide calculations in some loops of length OUTER_LOOPS

* GroupBN: Use C_ELEMENTS_PER_CTA=64 for the BN and BN_relu kernels to improve resnet50

* GroupBN: Use C_ELEMENTS_PER_CTA=64 for the BN_add_relu kernels for a ~10% end-to-end improvement on resnet50
parent 06053e19
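A quick sanity check on the tile arithmetic behind these bullets (not part of the commit; the constants are read off the two sides of the first hunk below). ELEMENTS_PER_LDG is defined as C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL, so both parameter sets keep the same 4-wide vector load while the channel tile per CTA, and hence the number of pixels covered per CTA pass, changes:

// Standalone sketch (not from the commit): vector width implied by the kernel
// parameters on each side of the first hunk below.
#include <cstdio>

constexpr int kThreadsPerCta = 512;
constexpr int kThreadsPerPixelOld = 16, kCElementsPerCtaOld = 64;   // left side
constexpr int kThreadsPerPixelNew = 32, kCElementsPerCtaNew = 128;  // right side

// ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL
constexpr int kLdgOld = kCElementsPerCtaOld / kThreadsPerPixelOld;  // 4
constexpr int kLdgNew = kCElementsPerCtaNew / kThreadsPerPixelNew;  // 4
static_assert(kLdgOld == 4 && kLdgNew == 4, "vector width per load is unchanged");

int main() {
  // The channel tile per CTA doubles, so pixels processed per pass halve.
  std::printf("pixels per CTA pass: old=%d new=%d\n",
              kThreadsPerCta / kThreadsPerPixelOld,   // 32
              kThreadsPerCta / kThreadsPerPixelNew);  // 16
  return 0;
}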
@@ -236,18 +236,18 @@ class NhwcBatchNorm {
   // Kernel params
   static const int USE_ONLINE_APPROACH = 1;
   static const int THREADS_PER_CTA = 512;
-  static const int THREADS_PER_PIXEL = 16;
-  static const int C_ELEMENTS_PER_CTA = 64;
+  static const int THREADS_PER_PIXEL = 32;
+  static const int C_ELEMENTS_PER_CTA = 128;
   static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
   static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;
   typedef uint16_t StorageType;
   //typedef float StorageType;
   // increasing this to 6 causes spills in fwd kernel!
-  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
-  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
-  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
-  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;
+  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 1;
+  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 1;
+  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 0;
+  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 0;
   static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
                                            PIXELS_PER_THREAD_IN_SMEM_FWD;
...
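For the "reduced buffering" bullet, a rough estimate of what the old staging cost. This assumes the SMEM staging buffer is laid out as pixels_in_smem × THREADS_PER_CTA × ELEMENTS_PER_LDG × sizeof(StorageType); that formula is an editor's assumption, not taken from the source:

// Rough SMEM-footprint estimate under the layout assumption stated above.
#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr int kThreadsPerCta = 512;
constexpr int kElementsPerLdg = 4;  // C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL

constexpr std::size_t smemBytes(int pixelsInSmem) {
  // 2-byte StorageType (uint16_t) per staged channel element.
  return std::size_t(pixelsInSmem) * kThreadsPerCta * kElementsPerLdg * sizeof(std::uint16_t);
}

static_assert(smemBytes(10) == 40960, "old FWD staging: ~40 KB of the 48 KB cap");
static_assert(smemBytes(0) == 0, "new FWD staging: none");

int main() {
  std::printf("old=%zu bytes, new=%zu bytes\n", smemBytes(10), smemBytes(0));
  return 0;
}

Under that assumption the old forward staging sat close to the 48 KB MAX_SMEM_WITHOUT_OPT_IN limit, and the new constants eliminate the SMEM staging entirely.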
@@ -248,17 +248,17 @@ class NhwcBatchNormAddRelu {
   // Kernel params
   static const int USE_ONLINE_APPROACH = 1;
   static const int THREADS_PER_CTA = 512;
-  static const int THREADS_PER_PIXEL = 16;
-  static const int C_ELEMENTS_PER_CTA = 64;
+  static const int THREADS_PER_PIXEL = 32;
+  static const int C_ELEMENTS_PER_CTA = 128;
   static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
   static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;
   typedef uint16_t StorageType;
   // increasing this to 6 causes spills in fwd kernel!
-  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
-  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
-  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
-  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;
+  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 1;
+  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 1;
+  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 0;
+  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 0;
   static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
                                            PIXELS_PER_THREAD_IN_SMEM_FWD;
@@ -559,7 +559,7 @@ const std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {
   const size_t num_variance_bytes = num_mean_bytes;
 #ifdef __HIP_PLATFORM_HCC__
-  int elems_per_group = ((m_ + 3) & ~3);
+  int elems_per_group = ((m_ + 3) & ~3) * 2;
 #else
   int elems_per_group = ((m_ + 31) & ~31) * 2;
 #endif
...
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
using bitmask_t = uint64_t; using bitmask_t = uint64_t;
#define BITMASK_OFFSET 1 #define BITMASK_OFFSET 2
#define ONE_BITMASK 1UL #define ONE_BITMASK 1UL
#else #else
using bitmask_t = unsigned int; using bitmask_t = unsigned int;
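A plausible reading of why BITMASK_OFFSET moves to 2 together with all the HIP-side "* 2" factors in this commit (editor's sketch, not from the source): the ReLU bitmask stores one bit per (pixel, channel) pair, so the words needed per pixel are C_ELEMENTS_PER_CTA divided by the word width. With 128 channels per CTA and 64-bit bitmask_t that is 2 words, exactly as the unchanged CUDA path already needs 2 × 32-bit words for 64 channels:

// Editor's sketch: bitmask words needed per pixel for one channel block.
constexpr int wordsPerPixel(int cElementsPerCta, int bitsPerWord) {
  return cElementsPerCta / bitsPerWord;  // one bit per (pixel, channel)
}

static_assert(wordsPerPixel(128, 64) == 2,
              "HIP after this commit: matches BITMASK_OFFSET 2 and the new '* 2' factors");
static_assert(wordsPerPixel(64, 64) == 1,
              "HIP before: matches the old BITMASK_OFFSET 1 and the missing '* 2'");
static_assert(wordsPerPixel(64, 32) == 2,
              "CUDA (unchanged): why the CUDA allocations already carry '* 2'");

int main() { return 0; }

The same factor shows up in numWorkspaceBytes, in the gmem_relu_bitmask pointer arithmetic, and in the Python bitmask shape computation further below.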
@@ -745,79 +745,72 @@ DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
 template< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >
 DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
   // The size of a warp.
-  const int THREADS_PER_WARP = warpSize;
-  // The number of warps in a CTA.
+#ifdef __HIP_PLATFORM_HCC__
+  const int THREADS_PER_WARP = 64;
+#else
+  const int THREADS_PER_WARP = 32;
+#endif
   const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
-  // The number of pixels computed by a single warp.
-  const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL;
-  // The position in the warp.
-  const int nhw_in_warp = nhw % PIXELS_PER_WARP;
-  // The C in the warp.
-  const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL;
-  // Store the values to shared memory.
-  write_to_smem(smem, threadIdx.x, x);
-  // Compute the parallel sums.
-  for (int offset = PIXELS_PER_WARP/2; offset > 0; offset /= 2) {
-    // NOP.
-    syncwarp();
-    // Read the running sum from the other thread.
-    float y[ELEMENTS_PER_LDG];
-    if (nhw_in_warp < offset) {
-      read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);
-    }
-    // Compute the updated sum.
-    add(x, y);
-    // NOP.
-    syncwarp();
-    // Update the sum in SMEM.
-    if (offset > 1 && nhw_in_warp < offset) {
-      write_to_smem(smem, threadIdx.x, x);
-    }
-  }
-  // The warps are done. Do the final reduction at the CTA level.
-  __syncthreads();
+  // The warp decomposition.
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+  const int lane_id = threadIdx.x % THREADS_PER_WARP;
+  // total size of data per sync iter
+#ifdef __HIP_PLATFORM_HCC__
+  for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) {
+    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+      x[i] += shfl_sync(x[i], offset + lane_id);
+    }
+  }
+#else
+  #pragma unroll
+  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+    x[i] += shfl_sync(x[i], THREADS_PER_PIXEL + lane_id);
+  }
+#endif
   // The warp leaders, write to SMEM.
-  const int idx = (threadIdx.x/THREADS_PER_WARP)*THREADS_PER_PIXEL + c_in_warp;
-  if (nhw_in_warp == 0) {
-    write_to_smem(smem, idx, x);
+  if (lane_id < THREADS_PER_PIXEL) {
+    write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
   }
   // The data is in SMEM. Do the final reduction.
   __syncthreads();
-  // Read the 1st element to prepare the work.
-  if (nhw < WARPS_PER_CTA/2) {
+  // The 1st warp does all the work.
+  // We do the final reduction: each half-warp sequentially reduces the final values.
+  if (warp_id == 0) {
     read_from_smem(x, smem, threadIdx.x);
-  }
-  // We have the running mean and running m2. Let's build the mean/var of the CTA.
-  for (int offset = WARPS_PER_CTA/2; offset > 0; offset /= 2) {
-    // NOP.
-    syncwarp();
-    // Read the mean and variance from the other pixel.
-    float y[ELEMENTS_PER_LDG];
-    if (nhw < offset) {
-      read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);
+    #pragma unroll
+    for (int offset = 1;
+         offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
+      float y[ELEMENTS_PER_LDG];
+      // Read the mean and variance from the other pixel.
+      read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
+      // Compute the updated sum.
+      add(x, y);
     }
-    // Compute the updated sum.
-    add(x, y);
-    // NOP.
+#ifdef __HIP_PLATFORM_HCC__
+    for (int offset = THREADS_PER_WARP >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) {
+      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+        x[i] += shfl_sync(x[i], offset + lane_id);
+      }
+    }
+#else
+    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+      x[i] += shfl_sync(x[i], THREADS_PER_PIXEL + lane_id);
+    }
+#endif
+    // Make sure the data was read from SMEM.
     syncwarp();
-    // Store the mean/var for the different pixels.
-    if (nhw < offset) {
+    // Store the final values.
+    if (threadIdx.x < THREADS_PER_PIXEL) {
+      // probably could do it earlier, before sync
       write_to_smem(smem, threadIdx.x, x);
     }
   }
 }
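To see why the rotated-lane shuffle in the new parallel_sums reduces each channel lane across the warp's pixels, here is a small host-side simulation (editor's sketch, not from the commit; it assumes shfl_sync(v, src) returns the value from lane src modulo warpSize, which matches the wrap-around behaviour of __shfl on both platforms):

// Host-side simulation of the HIP shuffle-reduction branch above.
#include <cassert>
#include <vector>

int main() {
  const int kWarp = 64;             // THREADS_PER_WARP on HIP
  const int kThreadsPerPixel = 32;  // new value from the first hunk
  std::vector<float> x(kWarp);
  for (int lane = 0; lane < kWarp; ++lane) x[lane] = float(lane);

  // Mirrors: for (offset = TPP; offset <= WARP/2; offset <<= 1)
  //            x[i] += shfl_sync(x[i], offset + lane_id);
  for (int offset = kThreadsPerPixel; offset <= kWarp / 2; offset <<= 1) {
    std::vector<float> y = x;  // every lane reads pre-iteration values
    for (int lane = 0; lane < kWarp; ++lane)
      x[lane] += y[(offset + lane) % kWarp];  // shuffle source wraps mod warpSize
  }

  // Lane c (c < TPP) now holds the sum over lanes {c, c + TPP, ...}: the
  // per-channel sum across the warp's pixels, ready for the warp leaders'
  // write_to_smem in the kernel above.
  for (int c = 0; c < kThreadsPerPixel; ++c)
    assert(x[c] == float(c) + float(c + kThreadsPerPixel));
  return 0;
}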
@@ -834,7 +827,7 @@ struct ParallelSums {
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
+/*
 template<>
 struct ParallelSums<16, 4> {
   template< int THREADS_PER_CTA >
...
@@ -855,6 +848,7 @@ struct ParallelSums<8, 4> {
     parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);
   }
 };
+*/
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1503,7 +1497,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
   bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask +
 #ifdef __HIP_PLATFORM_HCC__
-      ((params.nhw + 3) & ~3) * c_blk_index;
+      ((params.nhw + 3) & ~3) * 2 * c_blk_index;
 #else
       ((params.nhw + 31) & ~31) * 2 * c_blk_index;
 #endif
@@ -2661,7 +2655,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
   const bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask +
 #ifdef __HIP_PLATFORM_HCC__
-      ((params.nhw + 3) & ~3) * c_blk_index;
+      ((params.nhw + 3) & ~3) * 2 * c_blk_index;
 #else
       ((params.nhw + 31) & ~31) * 2 * c_blk_index;
 #endif
...
@@ -82,7 +82,7 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function):
                 nhw = x.shape[0] * x.shape[2] * x.shape[3]
             else:
                 nhw = x.shape[0] * x.shape[1] * x.shape[2]
-            shape = int(((nhw + 3) & ~3) * grid_dim_y)
+            shape = int(((nhw + 3) & ~3) * 2 * grid_dim_y)
             bitmask = torch.cuda.LongTensor(shape)
         else:
             bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)
...