Commit cdc17060 authored by luise.chen, committed by flyingdown

Luise/gbn optimization (#105)

* GroupBN: Reduce buffering to better hide calculations in some loops of length OUTER_LOOPS

* GroupBN: Use C_ELEMENTS_PER_CTA=64 for the BN and BN_relu kernels to improve ResNet-50 performance

* GroupBN: Use C_ELEMENTS_PER_CTA=64 for the BN_add_relu kernels for a ~10% end-to-end improvement on ResNet-50
parent 06053e19
@@ -236,18 +236,18 @@ class NhwcBatchNorm {
   // Kernel params
   static const int USE_ONLINE_APPROACH = 1;
   static const int THREADS_PER_CTA = 512;
-  static const int THREADS_PER_PIXEL = 16;
-  static const int C_ELEMENTS_PER_CTA = 64;
+  static const int THREADS_PER_PIXEL = 32;
+  static const int C_ELEMENTS_PER_CTA = 128;
   static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
   static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;
   typedef uint16_t StorageType;
   //typedef float StorageType;
   // increasing this to 6 causes spills in fwd kernel!
-  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
-  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
-  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
-  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;
+  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 1;
+  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 1;
+  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 0;
+  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 0;
   static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
                                            PIXELS_PER_THREAD_IN_SMEM_FWD;
@@ -248,17 +248,17 @@ class NhwcBatchNormAddRelu {
   // Kernel params
   static const int USE_ONLINE_APPROACH = 1;
   static const int THREADS_PER_CTA = 512;
-  static const int THREADS_PER_PIXEL = 16;
-  static const int C_ELEMENTS_PER_CTA = 64;
+  static const int THREADS_PER_PIXEL = 32;
+  static const int C_ELEMENTS_PER_CTA = 128;
   static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
   static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024;
   typedef uint16_t StorageType;
   // increasing this to 6 causes spills in fwd kernel!
-  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 5;
-  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 3;
-  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 10;
-  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 5;
+  static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 1;
+  static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 1;
+  static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 0;
+  static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 0;
   static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \
                                            PIXELS_PER_THREAD_IN_SMEM_FWD;
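Taken together, the new constants reshape the CTA tile: 32 threads now cooperate on each pixel, a CTA covers 128 channels, and per-thread buffering shrinks to a single register pixel with no shared-memory staging. Note that ELEMENTS_PER_LDG stays at 4 under both the old and the new pair. A minimal host-side sketch of the derived quantities (illustrative only; the names mirror the headers, and PIXELS_PER_LDG is assumed from the surrounding code rather than shown in this diff):

    // Illustrative sketch, not part of the patch: tile shape implied by the
    // new constants.
    #include <cstdio>

    int main() {
      const int THREADS_PER_CTA    = 512;
      const int THREADS_PER_PIXEL  = 32;   // threads cooperating on one pixel
      const int C_ELEMENTS_PER_CTA = 128;  // channels handled by one CTA
      // Channels each thread loads per LDG: 128 / 32 = 4 (unchanged from 64 / 16).
      const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL;
      // Pixels a CTA covers per load step: 512 / 32 = 16.
      const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL;
      printf("ELEMENTS_PER_LDG=%d, PIXELS_PER_LDG=%d\n",
             ELEMENTS_PER_LDG, PIXELS_PER_LDG);
      return 0;
    }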
@@ -559,7 +559,7 @@ const std::vector<size_t> NhwcBatchNormAddRelu::numWorkspaceBytes() const {
   const size_t num_variance_bytes = num_mean_bytes;
 #ifdef __HIP_PLATFORM_HCC__
-  int elems_per_group = ((m_ + 3) & ~3);
+  int elems_per_group = ((m_ + 3) & ~3) * 2;
 #else
   int elems_per_group = ((m_ + 31) & ~31) * 2;
 #endif
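With a CTA now spanning twice as many channels, the HIP path's per-group bitmask workspace doubles as well, which brings it in line with the existing * 2 factor on the CUDA path below it.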
@@ -36,7 +36,7 @@
 #ifdef __HIP_PLATFORM_HCC__
 using bitmask_t = uint64_t;
-#define BITMASK_OFFSET 1
+#define BITMASK_OFFSET 2
 #define ONE_BITMASK 1UL
 #else
 using bitmask_t = unsigned int;
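The new offset follows from the bitmask layout: the kernels record one ReLU bit per channel per pixel, so a CTA that covers 128 channels needs two 64-bit bitmask words per pixel where one sufficed before. A small sketch of the arithmetic (illustrative only; the nhw value is made up):

    // Illustrative sketch, not part of the patch: why BITMASK_OFFSET is now 2
    // on the HIP path.
    #include <cstdint>
    #include <cstdio>

    int main() {
      using bitmask_t = uint64_t;
      const int C_ELEMENTS_PER_CTA = 128;                    // channels per CTA
      const int BITS_PER_WORD = 8 * (int)sizeof(bitmask_t);  // 64
      const int BITMASK_OFFSET = C_ELEMENTS_PER_CTA / BITS_PER_WORD;  // 2 words per pixel
      const int nhw = 50176;                                 // hypothetical pixel count
      // Workspace elements per channel block: pixels rounded up to a multiple
      // of 4, times words per pixel (cf. numWorkspaceBytes above).
      const int elems_per_group = ((nhw + 3) & ~3) * BITMASK_OFFSET;
      printf("BITMASK_OFFSET=%d, elems_per_group=%d\n", BITMASK_OFFSET, elems_per_group);
      return 0;
    }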
@@ -745,79 +745,72 @@ DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
 template< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >
 DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
   // The size of a warp.
-  const int THREADS_PER_WARP = warpSize;
+#ifdef __HIP_PLATFORM_HCC__
+  const int THREADS_PER_WARP = 64;
+#else
+  const int THREADS_PER_WARP = 32;
+#endif
   // The number of warps in a CTA.
   const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
-  // The number of pixels computed by a single warp.
-  const int PIXELS_PER_WARP = THREADS_PER_WARP / THREADS_PER_PIXEL;
-  // The position in the warp.
-  const int nhw_in_warp = nhw % PIXELS_PER_WARP;
-  // The C in the warp.
-  const int c_in_warp = threadIdx.x % THREADS_PER_PIXEL;
-  // Store the values to shared memory.
-  write_to_smem(smem, threadIdx.x, x);
-  // Compute the parallel sums.
-  for (int offset = PIXELS_PER_WARP/2; offset > 0; offset /= 2) {
-    // NOP.
-    syncwarp();
-    // Read the running sum from the other thread.
-    float y[ELEMENTS_PER_LDG];
-    if (nhw_in_warp < offset) {
-      read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);
-    }
-    // Compute the updated sum.
-    add(x, y);
-    // NOP.
-    syncwarp();
-    // Update the sum in SMEM.
-    if (offset > 1 && nhw_in_warp < offset) {
-      write_to_smem(smem, threadIdx.x, x);
-    }
-  }
+  // The warp decomposition.
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+  const int lane_id = threadIdx.x % THREADS_PER_WARP;
+  // total size of data per sync iter
+#ifdef __HIP_PLATFORM_HCC__
+  for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) {
+    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+      x[i] += shfl_sync(x[i], offset + lane_id);
+    }
+  }
+#else
+  #pragma unroll
+  for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+    x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
+  }
+#endif
   // The warps are done. Do the final reduction at the CTA level.
   __syncthreads();
   // The warp leaders, write to SMEM.
-  const int idx = (threadIdx.x/THREADS_PER_WARP)*THREADS_PER_PIXEL + c_in_warp;
-  if (nhw_in_warp == 0) {
-    write_to_smem(smem, idx, x);
+  if (lane_id < THREADS_PER_PIXEL) {
+    write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x);
   }
   // The data is in SMEM. Do the final reduction.
   __syncthreads();
-  // Read the 1st element to prepare the work.
-  if (nhw < WARPS_PER_CTA/2) {
-    read_from_smem(x, smem, threadIdx.x);
-  }
-  // We have the running mean and running m2. Let's build the mean/var of the CTA.
-  for (int offset = WARPS_PER_CTA/2; offset > 0; offset /= 2) {
-    // NOP.
-    syncwarp();
-    // Read the mean and variance from the other pixel.
-    float y[ELEMENTS_PER_LDG];
-    if (nhw < offset) {
-      read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_PIXEL);
-    }
-    // Compute the updated sum.
-    add(x, y);
-    // NOP.
-    // Make sure the data was read from SMEM.
-    syncwarp();
-    // Store the mean/var for the different pixels.
-    if (nhw < offset) {
-      write_to_smem(smem, threadIdx.x, x);
-    }
-  }
+  // The 1st warp does all the work: each half-warp sequentially reduces the
+  // per-warp partial sums.
+  if (warp_id == 0) {
+    read_from_smem(x, smem, threadIdx.x);
+    #pragma unroll
+    for (int offset = 1;
+        offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
+      float y[ELEMENTS_PER_LDG];
+      // Read the mean and variance from the other pixel.
+      read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
+      // Compute the updated sum.
+      add(x, y);
+    }
+#ifdef __HIP_PLATFORM_HCC__
+    for (int offset = THREADS_PER_WARP >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) {
+      for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+        x[i] += shfl_sync(x[i], offset + lane_id);
+      }
+    }
+#else
+    for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+      x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id);
+    }
+#endif
+    // Store the final values.
+    if (threadIdx.x < THREADS_PER_PIXEL) {
+      // probably could do it earlier, before sync
+      write_to_smem(smem, threadIdx.x, x);
+    }
+  }
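The rewrite replaces the shared-memory ping-pong of the per-warp phase with register shuffles: each warp first folds the pixels it holds together entirely in registers, and shared memory is touched only once for the cross-warp pass. A standalone HIP sketch of the intra-warp folding pattern (an assumption for illustration: it uses __shfl_xor instead of the repo's shfl_sync helper, and assumes a 64-wide wavefront with 32 threads per pixel):

    // Illustrative sketch, not part of the patch: shuffle-based group
    // reduction on a 64-lane wavefront.
    #include <hip/hip_runtime.h>
    #include <cstdio>

    __global__ void group_reduce(float *out) {
      const int THREADS_PER_PIXEL = 32;  // lanes cooperating on one pixel
      float x = static_cast<float>(threadIdx.x);
      // Fold halves together until every THREADS_PER_PIXEL-wide group holds
      // the sum over all pixels its warp was carrying.
      for (int offset = warpSize / 2; offset >= THREADS_PER_PIXEL; offset /= 2) {
        x += __shfl_xor(x, offset);
      }
      out[threadIdx.x] = x;
    }

    int main() {
      float *d = nullptr;
      hipMalloc(&d, 64 * sizeof(float));
      hipLaunchKernelGGL(group_reduce, dim3(1), dim3(64), 0, 0, d);
      float h[64];
      hipMemcpy(h, d, sizeof(h), hipMemcpyDeviceToHost);
      // Lane 0 now holds 0 + 32, lane 1 holds 1 + 33, and so on.
      printf("lane 0: %.0f, lane 1: %.0f\n", h[0], h[1]);
      hipFree(d);
      return 0;
    }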
@@ -834,7 +827,7 @@ struct ParallelSums {
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
+/*
 template<>
 struct ParallelSums<16, 4> {
   template< int THREADS_PER_CTA >
@@ -855,6 +848,7 @@ struct ParallelSums<8, 4> {
     parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);
   }
 };
+*/
////////////////////////////////////////////////////////////////////////////////////////////////////
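With THREADS_PER_PIXEL now fixed at 32, the ParallelSums<16, 4> and ParallelSums<8, 4> specializations are never selected, so the patch comments them out instead of deleting them.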
@@ -1503,7 +1497,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
   bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask +
 #ifdef __HIP_PLATFORM_HCC__
-    ((params.nhw + 3) & ~3) * c_blk_index;
+    ((params.nhw + 3) & ~3) * 2 * c_blk_index;
 #else
     ((params.nhw + 31) & ~31) * 2 * c_blk_index;
 #endif
@@ -2661,7 +2655,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
   const bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask +
 #ifdef __HIP_PLATFORM_HCC__
-    ((params.nhw + 3) & ~3) * c_blk_index;
+    ((params.nhw + 3) & ~3) * 2 * c_blk_index;
 #else
     ((params.nhw + 31) & ~31) * 2 * c_blk_index;
 #endif
@@ -82,7 +82,7 @@ class bn_addrelu_NHWC_impl(torch.autograd.Function):
                 nhw = x.shape[0] * x.shape[2] * x.shape[3]
             else:
                 nhw = x.shape[0] * x.shape[1] * x.shape[2]
-            shape = int(((nhw + 3) & ~3) * grid_dim_y)
+            shape = int(((nhw + 3) & ~3) * 2 * grid_dim_y)
             bitmask = torch.cuda.LongTensor(shape)
         else:
             bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)
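The Python-side allocation mirrors the kernel change: on ROCm each pixel now needs BITMASK_OFFSET = 2 64-bit words per channel block, hence the extra factor of 2 in shape. The CUDA branch below was already sized with its own * 2 factor.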