Unverified commit d9c887c2 authored by jjsjann123, committed by GitHub

Merge pull request #360 from NVIDIA/gbn_update

update gbn
parents d68ec712 333e53f7
......@@ -52,16 +52,19 @@ at::Tensor nhwc_bn_fwd_train(
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
const int occupancy,
const int grid_dim_x,
const bool coop) {
const int N = x.size(0);
const int H = x.size(1);
......@@ -116,8 +119,8 @@ at::Tensor nhwc_bn_fwd_train(
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
......@@ -127,12 +130,9 @@ at::Tensor nhwc_bn_fwd_train(
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
// Don't fuse in ReLU for now at least
bn->fwd(stream, fuse_relu, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
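
The workspace change above recurs in every entry point in this diff: the per-call `THCudaMalloc` / `cudaMemsetAsync` / `THCudaFree` of the retired-CTA counter is replaced by a caller-provided `ret_cta` tensor. A minimal Python sketch of the allocation this relies on, mirroring the 8192-byte, zero-filled buffer created in `BatchNorm2d_NHWC.__init__` further down in this diff:

```python
import torch

# Allocated once per module and reused across iterations; the C++ entry points
# only assert that it is at least retired_cta_bytes long.
ret_cta = torch.cuda.ByteTensor(8192).fill_(0)
```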
......@@ -142,6 +142,7 @@ at::Tensor nhwc_bn_fwd_eval(
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon,
......@@ -196,8 +197,8 @@ at::Tensor nhwc_bn_fwd_eval(
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
......@@ -210,7 +211,6 @@ at::Tensor nhwc_bn_fwd_eval(
// Don't fuse in ReLU for now at least
bn->fwdInference(stream, fuse_relu);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
......@@ -224,16 +224,19 @@ std::vector<at::Tensor> nhwc_bn_bwd(
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
const int occupancy,
const int grid_dim_x,
const bool coop) {
// shape
const int N = x.size(0);
const int H = x.size(1);
......@@ -293,8 +296,8 @@ std::vector<at::Tensor> nhwc_bn_bwd(
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[2];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 3; index < workspace_bytes.size(); ++index) {
......@@ -304,10 +307,25 @@ std::vector<at::Tensor> nhwc_bn_bwd(
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
bn->dgrad(stream, fuse_relu, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return std::vector<at::Tensor>{x_grad, scale_grad, bias_grad};
}
int nhwc_bn_fwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
//max occupancy supported by the code is 2
return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2);
}
int nhwc_bn_bwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
//max occupancy supported by the code is 2
return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2);
}
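
These two helpers let the Python side query the kernels' achievable occupancy once, instead of recomputing it on every launch. A minimal sketch of how they are consumed, following the derivation used by `BatchNorm2d_NHWC.__init__` later in this diff (`max_cta_per_sm` and `cta_launch_margin` keep their previous defaults of 2 and 12):

```python
import torch
import bnp  # the extension module built from these sources

max_cta_per_sm = 2       # kernels are compiled for at most 2 CTAs per SM
cta_launch_margin = 12   # CTAs left free for concurrently running kernels

fwd_occupancy = min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)
mp_count = torch.cuda.get_device_properties(None).multi_processor_count
fwd_grid_dim_x = max(mp_count * fwd_occupancy - cta_launch_margin, 1)
# fwd_occupancy and fwd_grid_dim_x are then passed straight into bn_fwd_nhwc
```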
......@@ -56,11 +56,11 @@ class NhwcBatchNorm {
exit(-1);
}
void fwd(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
void dgrad(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void fwdInference(cudaStream_t stream, bool use_relu);
dim3 calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
dim3 calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
dim3 calc_bwd_grid(int *loop, const int grid_dim_x);
void setInputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
......@@ -256,8 +256,9 @@ class NhwcBatchNorm {
// version that was compiled with that occupancy in its launch bounds. This way, we avoid
// needless register spills.
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, int device_id, const int max_cta_per_sm) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY) \
dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto fwd_func = nhwc_batch_norm_fwd< \
......@@ -289,37 +290,45 @@ class NhwcBatchNorm {
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
int occupancy = smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
if (outer_loops == 1 && use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, true, false, 2);
LAUNCH_FWD_KERNEL(1, true, false, 2, coop);
else
LAUNCH_FWD_KERNEL(1, true, false, 1);
LAUNCH_FWD_KERNEL(1, true, false, 1, coop);
} else if (outer_loops == 1 && !use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, false, false, 2);
LAUNCH_FWD_KERNEL(1, false, false, 2, coop);
else
LAUNCH_FWD_KERNEL(1, false, false, 1);
LAUNCH_FWD_KERNEL(1, false, false, 1, coop);
} else if (use_relu) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, true, false, 2);
LAUNCH_FWD_KERNEL(0, true, false, 2, coop);
else
LAUNCH_FWD_KERNEL(0, true, false, 1);
LAUNCH_FWD_KERNEL(0, true, false, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, false, false, 2);
LAUNCH_FWD_KERNEL(0, false, false, 2, coop);
else
LAUNCH_FWD_KERNEL(0, false, false, 1);
LAUNCH_FWD_KERNEL(0, false, false, 1, coop);
}
#undef LAUNCH_FWD_KERNEL
}
......@@ -327,8 +336,8 @@ class NhwcBatchNorm {
// Helper function to launch the backward kernel.
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, bool use_relu, int device_id, const int max_cta_per_sm) {
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY) \
dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) {
#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_func = nhwc_batch_norm_bwd< \
......@@ -356,16 +365,25 @@ class NhwcBatchNorm {
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_FUNC>(bwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd coop serial kernel"); \
} while (0)
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY) \
#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \
auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \
......@@ -393,62 +411,50 @@ class NhwcBatchNorm {
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_RELU_FUNC>(bwd_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
int occupancy = smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
if (outer_loops == 1 && use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_RELU_KERNEL(1, 2);
LAUNCH_BWD_RELU_KERNEL(1, 2, coop);
else
LAUNCH_BWD_RELU_KERNEL(1, 1);
LAUNCH_BWD_RELU_KERNEL(1, 1, coop);
} else if (outer_loops == 1 && !use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_KERNEL(1, 2);
LAUNCH_BWD_KERNEL(1, 2, coop);
else
LAUNCH_BWD_KERNEL(1, 1);
LAUNCH_BWD_KERNEL(1, 1, coop);
} else if (use_relu) {
if (occupancy >= 2)
LAUNCH_BWD_RELU_KERNEL(0, 2);
LAUNCH_BWD_RELU_KERNEL(0, 2, coop);
else
LAUNCH_BWD_RELU_KERNEL(0, 1);
LAUNCH_BWD_RELU_KERNEL(0, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_BWD_KERNEL(0, 2);
LAUNCH_BWD_KERNEL(0, 2, coop);
else
LAUNCH_BWD_KERNEL(0, 1);
LAUNCH_BWD_KERNEL(0, 1, coop);
}
#undef LAUNCH_BWD_KERNEL
}
private:
// Calculate the max number of CTAs allowed in the grid for the fwd kernel.
static size_t max_fwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
using namespace at::cuda::utils;
int answer = MultiprocessorCount(device_id) * smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
if (SMArch(device_id) >= 70)
answer -= cta_launch_margin;
answer = std::max(1, answer); // we need at least one CTA to operate
return static_cast<size_t>(answer);
}
// Calculate the max number of CTAs allowed in the grid for the bwd kernel.
static size_t max_bwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
using namespace at::cuda::utils;
int answer = MultiprocessorCount(device_id) * smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
if (SMArch(device_id) >= 70)
answer -= cta_launch_margin;
answer = std::max(1, answer); // we need at least one CTA to operate
return static_cast<size_t>(answer);
}
public:
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
......@@ -603,11 +609,11 @@ void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) {
}
}
dim3 NhwcBatchNorm::calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
dim3 NhwcBatchNorm::calc_fwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = max_fwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
......@@ -626,11 +632,11 @@ dim3 NhwcBatchNorm::calc_fwd_grid(int device_id, int *loop, const int max_cta_pe
return grid_dim;
}
dim3 NhwcBatchNorm::calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
dim3 NhwcBatchNorm::calc_bwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = max_bwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
......@@ -649,7 +655,8 @@ dim3 NhwcBatchNorm::calc_bwd_grid(int device_id, int *loop, const int max_cta_pe
return grid_dim;
}
void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
......@@ -677,16 +684,18 @@ void NhwcBatchNorm::fwd(cudaStream_t stream, bool use_relu, int device_id, void*
NhwcBatchNormFwdParams params;
_setFwdParams(&params);
params.my_data = my_data;
params.pair_data = pair_data;
params.pair_data2 = pair_data2;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
params.sync_iters = bn_group >> 1;
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
dim3 grid_dim = calc_fwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, device_id, max_cta_per_sm);
dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}
void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
......@@ -711,14 +720,15 @@ void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, int device_id, voi
NhwcBatchNormBwdParams params;
_setBwdParams(&params);
params.my_data = my_data;
params.pair_data = pair_data;
params.pair_data2 = pair_data2;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
params.sync_iters = bn_group >> 1;
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
params.wgrad_coeff = 1.0 / bn_group;
dim3 grid_dim = calc_bwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
_bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, device_id, max_cta_per_sm);
dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
_bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_
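
The `sync_iters` expression above changes from `bn_group >> 1` to `(bn_group==8)?3:(bn_group >> 1)`: each sync iteration exchanges partial sums with one peer buffer (`pair_datas[0..2]`), so the number of rounds should grow with log2 of the group size, and an 8-GPU group needs three rounds rather than four. A small illustrative check (not part of the commit):

```python
# Hedged illustration: the bn_group -> sync_iters mapping now matches
# log2(bn_group), i.e. one pairwise exchange round per doubling of the group.
def sync_iters(bn_group):
    return 3 if bn_group == 8 else bn_group >> 1

assert [sync_iters(g) for g in (1, 2, 4, 8)] == [0, 1, 2, 3]
```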
......@@ -55,15 +55,18 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
const int occupancy,
const int grid_dim_x,
const bool coop) {
const int N = x.size(0);
const int H = x.size(1);
......@@ -121,8 +124,9 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
......@@ -132,12 +136,9 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
// Don't fuse in ReLU for now at least
bn->fwd(stream, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
......@@ -148,6 +149,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon) {
......@@ -204,8 +206,8 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
......@@ -218,7 +220,6 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
// Don't fuse in ReLU for now at least
bn->fwdInference(stream);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return y;
}
......@@ -233,15 +234,18 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void * my_data,
void * pair_data,
void * pair_data2,
void * pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin) {
const int occupancy,
const int grid_dim_x,
const bool coop) {
// shape
const int N = x.size(0);
const int H = x.size(1);
......@@ -305,8 +309,8 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int retired_cta_bytes = workspace_bytes[3];
void* retired_ctas = THCudaMalloc(at::globalContext().lazyInitCUDA(), retired_cta_bytes);
cudaMemsetAsync(retired_ctas, 0, retired_cta_bytes, stream); //FIXME: is this legit?
void* retired_ctas = ret_cta.data<uint8_t>();
assert(ret_cta.size(0)>=retired_cta_bytes);
workspace.push_back(retired_ctas);
for (auto index = 4; index < workspace_bytes.size(); ++index) {
......@@ -316,10 +320,24 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
bn->setWorkspacePointers(workspace, workspace_bytes);
int device_id;
cudaGetDevice(&device_id);
bn->dgrad(stream, device_id, my_data, pair_data, pair_data2, bn_group, *magic, max_cta_per_sm, cta_launch_margin);
bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop);
THCudaFree(at::globalContext().lazyInitCUDA(), retired_ctas);
return std::vector<at::Tensor>{x_grad, z_grad, scale_grad, bias_grad};
}
int nhwc_bn_addrelu_fwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
//max occupancy supported by the code is 2
return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2);
}
int nhwc_bn_addrelu_bwd_occupancy() {
int device_id=-1;
cudaGetDevice(&device_id);
//max occupancy supported by the code is 2
return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2);
}
......@@ -56,11 +56,11 @@ class NhwcBatchNormAddRelu {
exit(-1);
}
void fwd(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
void dgrad(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin);
void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop);
void fwdInference(cudaStream_t stream);
dim3 calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
dim3 calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin);
dim3 calc_fwd_grid(int *loop, const int grid_dim_x);
dim3 calc_bwd_grid(int *loop, const int grid_dim_x);
void setInputDescriptor(const cudnnTensorFormat_t format,
const cudnnDataType_t data_type,
......@@ -260,8 +260,8 @@ class NhwcBatchNormAddRelu {
// version that was compiled with that occupancy in its launch bounds. This way, we avoid
// needless register spills.
void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params,
dim3 grid_dim, int outer_loops, int device_id, const int max_cta_per_sm) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY) \
dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
......@@ -294,27 +294,35 @@ class NhwcBatchNormAddRelu {
USE_RELU, \
USE_ADD_RELU, \
COMPILED_FOR_OCCUPANCY>); \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
if (COOP) { \
cudaLaunchCooperativeKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} else { \
cudaLaunchKernel<FWD_FUNC>(fwd_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_FWD, \
stream); \
} \
checkCudaStatus(name_ + " fwd ser coop kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
int occupancy = smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
if (outer_loops == 1) {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(1, false, true, 2);
LAUNCH_FWD_KERNEL(1, false, true, 2, coop);
else
LAUNCH_FWD_KERNEL(1, false, true, 1);
LAUNCH_FWD_KERNEL(1, false, true, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_FWD_KERNEL(0, false, true, 2);
LAUNCH_FWD_KERNEL(0, false, true, 2, coop);
else
LAUNCH_FWD_KERNEL(0, false, true, 1);
LAUNCH_FWD_KERNEL(0, false, true, 1, coop);
}
#undef LAUNCH_FWD_KERNEL
}
......@@ -322,8 +330,8 @@ class NhwcBatchNormAddRelu {
// Helper function to launch the backward kernel.
void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params,
dim3 grid_dim, int outer_loops, int device_id, const int max_cta_per_sm) {
#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY) \
dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) {
#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \
do { \
CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \
"Nhwc batchnormaddrelu kernel smem too big."; \
......@@ -354,52 +362,40 @@ class NhwcBatchNormAddRelu {
USE_ONLINE_APPROACH, \
OUTER_LOOPS, \
COMPILED_FOR_OCCUPANCY>); \
cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
if (COOP) { \
cudaLaunchCooperativeKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} else { \
cudaLaunchKernel<BWD_ADD_RELU_FUNC>(bwd_add_relu_func, \
grid_dim, \
THREADS_PER_CTA, \
&params_ptr, \
SMEM_SIZE_BWD, \
stream); \
} \
checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \
} while (0)
// Don't try for an occupancy > 2 as this will squeeze register use and create spills.
int occupancy = smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
if (outer_loops == 1) {
if (occupancy >= 2)
LAUNCH_BWD_ADD_RELU_KERNEL(1, 2);
LAUNCH_BWD_ADD_RELU_KERNEL(1, 2, coop);
else
LAUNCH_BWD_ADD_RELU_KERNEL(1, 1);
LAUNCH_BWD_ADD_RELU_KERNEL(1, 1, coop);
} else {
if (occupancy >= 2)
LAUNCH_BWD_ADD_RELU_KERNEL(0, 2);
LAUNCH_BWD_ADD_RELU_KERNEL(0, 2, coop);
else
LAUNCH_BWD_ADD_RELU_KERNEL(0, 1);
LAUNCH_BWD_ADD_RELU_KERNEL(0, 1, coop);
}
#undef LAUNCH_BWD_KERNEL
}
private:
// Calculate the max number of CTAs allowed in the grid for the fwd kernel.
static size_t max_fwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
using namespace at::cuda::utils;
int answer = MultiprocessorCount(device_id) * smem_driven_fwd_occupancy(device_id, max_cta_per_sm);
if (SMArch(device_id) >= 70)
answer -= cta_launch_margin;
answer = std::max(1, answer); // we need at least one CTA to operate
return static_cast<size_t>(answer);
}
// Calculate the max number of CTAs allowed in the grid for the bwd kernel.
static size_t max_bwd_grid_x(int device_id, const int max_cta_per_sm, const int cta_launch_margin) {
using namespace at::cuda::utils;
int answer = MultiprocessorCount(device_id) * smem_driven_bwd_occupancy(device_id, max_cta_per_sm);
if (SMArch(device_id) >= 70)
answer -= cta_launch_margin;
answer = std::max(1, answer); // we need at least one CTA to operate
return static_cast<size_t>(answer);
}
public:
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
......@@ -553,11 +549,11 @@ void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) {
checkCudaStatus(name_ + " fwd_inference-relu kernel");
}
dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = max_fwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
......@@ -576,11 +572,11 @@ dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int device_id, int *loop, const int max
return grid_dim;
}
dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int device_id, int *loop, const int max_cta_per_sm, const int cta_launch_margin) {
dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int *loop, const int grid_dim_x) {
dim3 grid_dim;
grid_dim.x = div_up(m_, PIXELS_PER_CTA_BWD);
int c_blks = div_up(c_, C_ELEMENTS_PER_CTA);
unsigned int max_grid_x = max_bwd_grid_x(device_id, max_cta_per_sm, cta_launch_margin);
unsigned int max_grid_x = grid_dim_x;
if (grid_dim.x <= max_grid_x) {
*loop = 1;
if (max_grid_x / grid_dim.x > 1) {
......@@ -599,7 +595,8 @@ dim3 NhwcBatchNormAddRelu::calc_bwd_grid(int device_id, int *loop, const int max
return grid_dim;
}
void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
......@@ -630,16 +627,18 @@ void NhwcBatchNormAddRelu::fwd(cudaStream_t stream, int device_id, void* my_data
_setFwdParams(&params);
params.my_data = my_data;
params.pair_data = pair_data;
params.pair_data2 = pair_data2;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
params.sync_iters = bn_group >> 1;
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
dim3 grid_dim = calc_fwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, device_id, max_cta_per_sm);
dim3 grid_dim = calc_fwd_grid(&params.outer_loops, grid_dim_x);
_fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
}
void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, int device_id, void* my_data, void* pair_data, void* pair_data2, const int bn_group, const int magic, const int max_cta_per_sm, const int cta_launch_margin) {
void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3,
const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) {
bool ptrs_are_set =
X_tensor_desc_ != nullptr
&& Y_tensor_desc_ != nullptr
......@@ -668,14 +667,15 @@ void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, int device_id, void* my_da
_setBwdParams(&params);
params.my_data = my_data;
params.pair_data = pair_data;
params.pair_data2 = pair_data2;
params.pair_datas[0] = pair_data;
params.pair_datas[1] = pair_data2;
params.pair_datas[2] = pair_data3;
params.magic = magic;
params.sync_iters = bn_group >> 1;
params.sync_iters = (bn_group==8)?3:(bn_group >> 1);
params.wgrad_coeff = 1.0 / bn_group;
dim3 grid_dim = calc_bwd_grid(device_id, &params.outer_loops, max_cta_per_sm, cta_launch_margin);
_bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, device_id, max_cta_per_sm);
dim3 grid_dim = calc_bwd_grid(&params.outer_loops, grid_dim_x);
_bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop);
}
#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_
......@@ -7,17 +7,6 @@ namespace cuda {
namespace utils {
//eventually should be replaced by real query functions
static inline int MultiprocessorCount(int device_id) {
return getDeviceProperties(device_id)->multiProcessorCount;
}
static inline int SMArch(int device_id) {
auto device_property = getDeviceProperties(device_id);
int cc = device_property->major * 10 + device_property->minor;
return cc;
}
static inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
return getDeviceProperties(device_id)->sharedMemPerMultiprocessor;
}
......
......@@ -36,16 +36,19 @@ at::Tensor nhwc_bn_fwd_train(
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
const int occupancy,
const int grid_dim_x,
const bool coop);
at::Tensor nhwc_bn_fwd_eval(
const at::Tensor& x,
......@@ -53,6 +56,7 @@ at::Tensor nhwc_bn_fwd_eval(
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon,
......@@ -67,16 +71,19 @@ std::vector<at::Tensor> nhwc_bn_bwd(
const at::Tensor& running_inv_var,
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
const bool fuse_relu,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
const int occupancy,
const int grid_dim_x,
const bool coop);
at::Tensor nhwc_bn_addrelu_fwd_train(
const at::Tensor& x,
......@@ -88,15 +95,18 @@ at::Tensor nhwc_bn_addrelu_fwd_train(
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
const int occupancy,
const int grid_dim_x,
const bool coop);
at::Tensor nhwc_bn_addrelu_fwd_eval(
const at::Tensor& x,
......@@ -105,6 +115,7 @@ at::Tensor nhwc_bn_addrelu_fwd_eval(
const at::Tensor& bias,
const at::Tensor& running_mean,
const at::Tensor& running_inv_var,
const at::Tensor& ret_cta,
const int bn_group,
const float momentum,
const float epsilon);
......@@ -119,16 +130,24 @@ std::vector<at::Tensor> nhwc_bn_addrelu_bwd(
const at::Tensor& minibatch_mean,
const at::Tensor& minibatch_inv_var,
const at::Tensor& bitmask,
const at::Tensor& ret_cta,
const float momentum,
const float epsilon,
void* my_data,
void* pair_data,
void* pair_data2,
void* pair_data3,
const int bn_group,
const at::Tensor& magic_tensor,
const int max_cta_per_sm,
const int cta_launch_margin);
const int occupancy,
const int grid_dim_x,
const bool coop);
int nhwc_bn_fwd_occupancy();
int nhwc_bn_bwd_occupancy();
int nhwc_bn_addrelu_fwd_occupancy();
int nhwc_bn_addrelu_bwd_occupancy();
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
......@@ -141,8 +160,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc");
m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc");
m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy");
m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy");
m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc");
m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc");
m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc");
m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy");
m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy");
}
......@@ -58,7 +58,7 @@ const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
const int BYTES_PER_ELEM = 4;
// Buffer size per sync step
const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*(1+ELEMENTS_PER_LDG)*BYTES_PER_ELEM;
const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*2*ELEMENTS_PER_LDG*BYTES_PER_ELEM;
};
class IpcMemHandleRegistry {
......
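
The `SINGLE_SYNC_BUFFER_BYTES` change above reflects the new exchange layout: instead of a flags block followed by a data block (`1 + ELEMENTS_PER_LDG` words per thread), each data word is now interleaved with a magic flag word (`2 * ELEMENTS_PER_LDG` words per thread), which is what the rewritten `parallel_sums_16x2` below reads and writes. A small worked check, assuming `ELEMENTS_PER_LDG` is 4 as in the 16x2 reduction path:

```python
ELEMENTS_PER_LDG = 4                           # assumed: four floats per LDG in the 16x2 path
old_words_per_thread = 1 + ELEMENTS_PER_LDG    # one flag word plus a block of data words
new_words_per_thread = 2 * ELEMENTS_PER_LDG    # (data, flag) pairs, interleaved
assert (old_words_per_thread, new_words_per_thread) == (5, 8)
```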
......@@ -364,9 +364,11 @@ DEVICE_FUNCTION void relu_activation(float (&x)[N]) {
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, void* params_my_data, void* params_pair_data, int off, const int magic, void* params_pair_data2, const unsigned int& sync_iters) {
DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
void* params_my_data, void** params_pair_datas, int off,
const int magic,
const int sync_iters) {
// The size of a warp.
const int THREADS_PER_WARP = 32;
// The number of warps in a CTA.
......@@ -383,11 +385,8 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, voi
// The warp decomposition.
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int lane_id = threadIdx.x % THREADS_PER_WARP;
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("start parallel_sums_16x2 off=%d magic=%d sync_iters=%d thread%d block %d , %d\n", off, magic, sync_iters, threadIdx.x, blockIdx.x, blockIdx.y);
#endif
// total size of data per sync iter
const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2;
#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
......@@ -426,69 +425,52 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, voi
// Store the final values.
if (threadIdx.x < THREADS_PER_PIXEL) {
//probably could do it earlier, before sync
for (int sync_iter=0; sync_iter<sync_iters; ++sync_iter)
{
// total size of flags per sync iter, to be skiped for data
const int flags_total = MAX_OFFSET*THREADS_PER_PIXEL;
// total size of data per sync iter
const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG;
//skip the space consumed by previous sync iterations
const int xbuf_offset = sync_iter*(flags_total+data_total);
// flags are at the begining of the buffer, one per thread
const int flags_offset = xbuf_offset + off*THREADS_PER_PIXEL;
// data starts after flags, but have to skip previous
const int data_offset = xbuf_offset + flags_total + off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL + ELEMENTS_PER_LDG*threadIdx.x;
//after sums for this GPU were computed, let CTA0 broadcast the sum to over GPU
if (blockIdx.x==0)
{
volatile float * write_data = &(((float*)params_pair_data)[data_offset]);
volatile int32_t * write_flag = &(((int32_t*)((params_pair_data)))[flags_offset]);
//write the data to memory region to be reflected to other GPU
asm volatile ("st.global.wt.v4.f32 [%0], {%1,%2,%3,%4};"
:: "l"((float4 *)write_data) , "f"(x[0]), "f"( x[1]), "f"(x[2]), "f"( x[3]));
__threadfence_system();
//write the magic value to indicate data readiness
write_flag[threadIdx.x] = magic; //or can sync and set only one flag
#ifdef BNDEBUG
printf("writing buddy flag, thread %d myvalue %d data offset %d flag offset %d\n", threadIdx.x, magic, 4*THREADS_PER_PIXEL+off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL + ELEMENTS_PER_LDG*threadIdx.x, off*THREADS_PER_PIXEL);
#endif
// probably could do it earlier, before sync
for (int sync_iter=0; sync_iter < sync_iters; ++sync_iter) {
//float* params_pair_data = (reinterpret_cast<float**>(params_pair_datas))[sync_iter];
void* params_pair_data = params_pair_datas[sync_iter];
// skip the space consumed by previous sync iterations
const int xbuf_offset = sync_iter*data_total;
// data starts after flags, but have to skip previous
const int data_offset = xbuf_offset
+ off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL*2
+ ELEMENTS_PER_LDG*threadIdx.x*2;
// after sums for this GPU were computed, let CTA0 broadcast the sum to the other GPU
if (blockIdx.x == 0) {
volatile float * write_data =
&((reinterpret_cast<float*>(params_pair_data))[data_offset]);
// write the data to memory region to be reflected to other GPU
asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
:: "l"(write_data) , "f"(x[0]), "r"(magic), "f"(x[2]), "r"(magic));
asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};"
:: "l"(write_data+4) , "f"(x[1]), "r"(magic), "f"(x[3]), "r"(magic));
}
//now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU
volatile float * read_data_ = &(((float*)params_my_data)[data_offset]);
volatile int32_t * read_flag = &(((int32_t*)((params_my_data)))[flags_offset]);
//check if other side has written
#ifdef BNDEBUG
unsigned int safety=0;
while ((read_flag[threadIdx.x] % 1000000) != (magic % 1000000) )
{
++safety;
if (safety>99999) {
printf("stuck waiting for my buddy, thread %d myvalue %d data offset %d flag offset %d read value %d\n", threadIdx.x, magic, 4*THREADS_PER_PIXEL+off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL + ELEMENTS_PER_LDG*threadIdx.x, off*THREADS_PER_PIXEL, read_flag[threadIdx.x]);
safety=0;
}
}
#else
while ((read_flag[threadIdx.x] ) != (magic ) ) ;
#endif
// now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU
volatile float * read_data =
&((reinterpret_cast<float*>(params_my_data))[data_offset]);
float other[4];
asm volatile ("ld.global.cv.v4.f32 {%0, %1, %2, %3}, [%4];"
: "=f"(other[0]), "=f"(other[1]), "=f"(other[2]), "=f"(other[3]) : "l"(read_data_));
uint32_t other_flag_a, other_flag_b;
do {
asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
: "=f"(other[0]), "=r"(other_flag_a), "=f"(other[2]), "=r"(other_flag_b) : "l"(read_data));
} while ((other_flag_a != magic) || (other_flag_b != magic));
do {
asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
: "=f"(other[1]), "=r"(other_flag_a), "=f"(other[3]), "=r"(other_flag_b) : "l"(read_data+4));
} while ((other_flag_a != magic) || (other_flag_b != magic));
add(x, other);
params_pair_data = params_pair_data2; //FIXME use an array
}
// finally, after syncing up and accounting for partial sums from other GPUs as required, write the result
// finally, after syncing up and accounting for partial sums from
// other GPUs as required, write the result
write_to_smem(smem, threadIdx.x, x);
......@@ -655,12 +637,12 @@ template<>
struct ParallelSums<16, 4> {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0, 0);
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, 0, 0, 0, 0, 0);
}
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void* params_pair_data, int off, const int magic, void* params_pair_data2, const unsigned int& sync_iters) {
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_data, off, magic, params_pair_data2, sync_iters);
DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas, int off, const int magic, const unsigned int& sync_iters) {
parallel_sums_16x2<THREADS_PER_CTA>(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters);
}
};
......@@ -668,9 +650,6 @@ template<>
struct ParallelSums<8, 4> {
template< int THREADS_PER_CTA >
DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) {
#ifdef BNDEBUGX
assert(0);
#endif
parallel_sums_8x4<THREADS_PER_CTA>(smem, x, nhw);
}
};
......@@ -687,10 +666,6 @@ static inline int div_up(int m, int n) {
// It is expected that all threads in the CTA enter this function!
DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) {
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("start inter_block_sync thread%d block %d , %d grid.X %d\n", threadIdx.x, blockIdx.x, blockIdx.y, gridDim.x);
#endif
// Register the CTA.
if (threadIdx.x == 0) {
......@@ -714,10 +689,6 @@ DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count
} while (retired_ctas != 0);
}
__syncthreads();
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("finish inter_block_sync thread%d block %d , %d\n", threadIdx.x, blockIdx.x, blockIdx.y);
#endif
}
......@@ -858,8 +829,7 @@ struct NhwcBatchNormFwdParams {
int c_blks;
void* my_data;
void* pair_data;
void* pair_data2;
void* pair_datas[4];
int magic;
int sync_iters;
};
......@@ -1185,7 +1155,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
if (params.sync_iters>0)
{
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, m1, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+3, params.magic, params.pair_data2, params.sync_iters);
smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+3, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, m1, thread_in_cta_nhw);
......@@ -1242,7 +1212,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
if (params.sync_iters>0)
{
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, m2, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+2, params.magic, params.pair_data2, params.sync_iters);
smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+2, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, m2, thread_in_cta_nhw);
......@@ -1440,8 +1410,7 @@ struct NhwcBatchNormBwdParams {
int c_blks;
void* my_data;
void* pair_data;
void* pair_data2;
void* pair_datas[4];
int magic;
int sync_iters;
float wgrad_coeff;
......@@ -1582,11 +1551,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("starting nhwc_batch_norm_bwd\n");
#endif
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
......@@ -1778,7 +1742,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// dscale parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+1, params.magic, params.pair_data2, params.sync_iters);
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
......@@ -1792,7 +1756,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// dbias parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+0, params.magic, params.pair_data2, params.sync_iters);
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
......@@ -1950,10 +1914,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("starting nhwc_batch_norm_bwd_relu\n");
#endif
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
......@@ -2172,7 +2132,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// dscale parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+1, params.magic, params.pair_data2, params.sync_iters);
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
......@@ -2186,7 +2146,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// dbias parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+0, params.magic, params.pair_data2, params.sync_iters);
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
......@@ -2343,11 +2303,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// Shared memory buffer to store the extra pixels.
extern __shared__ PackedStorageType smem_storage_packed[];
#ifdef BNDEBUGX
if (threadIdx.x==0)
printf("starting nhwc_batch_norm_bwd_add_relu\n");
#endif
for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) {
// The position in the NHW dimension where the CTA starts.
int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS;
......@@ -2595,7 +2550,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// dscale parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+1, params.magic, params.pair_data2, params.sync_iters);
smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dscale, thread_in_cta_nhw);
......@@ -2609,7 +2564,7 @@ __global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY)
// dbias parallel sum
if (params.sync_iters>0) {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatchX<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_data, 4*c_blk_index+0, params.magic, params.pair_data2, params.sync_iters);
smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters);
} else {
ParallelSums<THREADS_PER_PIXEL, ELEMENTS_PER_LDG>::dispatch<THREADS_PER_CTA>(
smem, dbias, thread_in_cta_nhw);
......
......@@ -6,91 +6,105 @@ import bnp
class bn_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu=False, is_train=True, bn_group=1, my_data=None, pair_data=None, magic=1, pair_data2=None, max_cta_per_sm=2, cta_launch_margin=12):
def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
if is_train:
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv)
ctx.epsilon = epsilon
ctx.momentum = mom
ctx.ret_cta = ret_cta
ctx.fuse_relu = fuse_relu
ctx.my_data = my_data
ctx.pair_data = pair_data
ctx.magic = magic
ctx.pair_data2 = pair_data2
ctx.pair_data3 = pair_data3
ctx.bn_group = bn_group
ctx.max_cta_per_sm = max_cta_per_sm
ctx.cta_launch_margin = cta_launch_margin
ctx.bwd_occup = bwd_occup
ctx.bwd_grid_x = bwd_grid_x
ctx.multi_stream = multi_stream
res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)
return res
else:
return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, bn_group, mom, epsilon, fuse_relu)
return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu)
@staticmethod
def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables
epsilon = ctx.epsilon
mom = ctx.momentum
ret_cta = ctx.ret_cta
fuse_relu = ctx.fuse_relu
my_data = ctx.my_data
pair_data = ctx.pair_data
magic = ctx.magic
pair_data2 = ctx.pair_data2
pair_data3 = ctx.pair_data3
bn_group = ctx.bn_group
max_cta_per_sm = ctx.max_cta_per_sm
cta_launch_margin = ctx.cta_launch_margin
bwd_occup = ctx.bwd_occup
bwd_grid_x = ctx.bwd_grid_x
multi_stream = ctx.multi_stream
dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)
return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class bn_addrelu_NHWC_impl(torch.autograd.Function):
@staticmethod
def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, mom, epsilon, is_train=True, bn_group=1, my_data=None, pair_data=None, magic=1, pair_data2=None, max_cta_per_sm=2, cta_launch_margin=12):
def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream):
if is_train:
bitmask = torch.cuda.IntTensor(x.numel()//32)
bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y)
ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask)
ctx.epsilon = epsilon
ctx.momentum = mom
ctx.ret_cta = ret_cta
ctx.my_data = my_data
ctx.pair_data = pair_data
ctx.magic = magic
ctx.pair_data2 = pair_data2
ctx.pair_data3 = pair_data3
ctx.bn_group = bn_group
ctx.max_cta_per_sm = max_cta_per_sm
ctx.cta_launch_margin = cta_launch_margin
ctx.bwd_occup = bwd_occup
ctx.bwd_grid_x = bwd_grid_x
ctx.multi_stream = multi_stream
res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, mom, epsilon, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream)
return res
else:
return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, bn_group, mom, epsilon)
return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon)
@staticmethod
def backward(ctx, grad_y):
x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables
epsilon = ctx.epsilon
mom = ctx.momentum
ret_cta = ctx.ret_cta
my_data = ctx.my_data
pair_data = ctx.pair_data
magic = ctx.magic
pair_data2 = ctx.pair_data2
pair_data3 = ctx.pair_data3
bn_group = ctx.bn_group
max_cta_per_sm = ctx.max_cta_per_sm
cta_launch_margin = ctx.cta_launch_margin
bwd_occup = ctx.bwd_occup
bwd_grid_x = ctx.bwd_grid_x
multi_stream = ctx.multi_stream
dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, mom, epsilon, my_data, pair_data, pair_data2, bn_group, magic, max_cta_per_sm, cta_launch_margin)
dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream)
return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None
return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class BatchNorm2d_NHWC(_BatchNorm):
def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12):
# if using BatchNorm2d_NHWC simultaneously with multiple streams, set multi_stream to True
def __init__(self, num_features, fuse_relu=False, bn_group=1, max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False):
super(BatchNorm2d_NHWC, self).__init__(num_features)
self.fuse_relu = fuse_relu
self.multi_stream = multi_stream
self.minibatch_mean = torch.cuda.FloatTensor(num_features)
self.minibatch_riv = torch.cuda.FloatTensor(num_features)
......@@ -102,13 +116,32 @@ class BatchNorm2d_NHWC(_BatchNorm):
self.my_data = None
self.pair_data = None
self.pair_data2 = None
self.pair_data3 = None
self.local_rank = 0
self.magic = torch.IntTensor([0])
#calculate cta per sm occupancies
assert(max_cta_per_sm>0) # won't be able to do much with 0 CTAs :)
self.fwd_occupancy = min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm)
self.bwd_occupancy = min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm)
self.addrelu_fwd_occupancy = min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm)
self.addrelu_bwd_occupancy = min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm)
#calculate grid dimensions based on occupancy numbers
mp_count = torch.cuda.get_device_properties(None).multi_processor_count
self.fwd_grid_dim_x = max(mp_count*self.fwd_occupancy - cta_launch_margin , 1)
self.bwd_grid_dim_x = max(mp_count*self.bwd_occupancy - cta_launch_margin , 1)
self.addrelu_fwd_grid_dim_x = max(mp_count*self.addrelu_fwd_occupancy - cta_launch_margin , 1)
self.addrelu_bwd_grid_dim_x = max(mp_count*self.addrelu_bwd_occupancy - cta_launch_margin , 1)
self.grid_dim_y = (num_features + 63) // 64
# allocate scratch space used by implementation
# TODO: scratch space that should not be exposed to user code. It only needs one-time initialization, and the
# same buffer can be reused across iterations. It is kept here, instead of requesting a new buffer from the
# caching allocator, to avoid re-initializing it on every iteration.
self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0)
#FIXME: turn pair handles into an array
if bn_group>1:
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
......@@ -118,6 +151,8 @@ class BatchNorm2d_NHWC(_BatchNorm):
bn_sync_steps = 1
if (bn_group==4):
bn_sync_steps = 2
if (bn_group==8):
bn_sync_steps = 3
self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps))
self.my_data = bnp.get_data_ptr(self.ipc_buffer)
......@@ -148,6 +183,11 @@ class BatchNorm2d_NHWC(_BatchNorm):
pair_offset2 = offsets_l[local_rank ^ 2].cpu()
self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2)
if bn_group>4:
self.pair_handle3 = handles_l[local_rank ^ 3].cpu().contiguous()
pair_offset3 = offsets_l[local_rank ^ 3].cpu()
self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3)
#FIXME: get magic value into C code and eliminate from here
self.magic = torch.IntTensor([2])
self.local_rank = local_rank
......@@ -159,21 +199,27 @@ class BatchNorm2d_NHWC(_BatchNorm):
return bn_addrelu_NHWC_impl.apply(x, z,
self.weight, self.bias,
self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv,
self.minibatch_mean, self.minibatch_riv, self.grid_dim_y, self.ret_cta,
self.momentum,
self.eps, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2,
self.max_cta_per_sm, self.cta_launch_margin)
self.eps, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x,
self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x,
self.multi_stream)
else:
return bn_NHWC_impl.apply(x,
self.weight, self.bias,
self.running_mean, self.running_var,
self.minibatch_mean, self.minibatch_riv,
self.minibatch_mean, self.minibatch_riv, self.ret_cta,
self.momentum,
self.eps, self.fuse_relu, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2,
self.max_cta_per_sm, self.cta_launch_margin)
self.eps, self.fuse_relu, self.training, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3,
self.fwd_occupancy, self.fwd_grid_dim_x,
self.bwd_occupancy, self.bwd_grid_dim_x,
self.multi_stream)
def __del__(self):
if self.bn_group>1:
bnp.close_remote_data(self.pair_handle)
if self.bn_group>2:
bnp.close_remote_data(self.pair_handle2)
if self.bn_group>4:
bnp.close_remote_data(self.pair_handle3)
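
A minimal usage sketch for the updated module, assuming the constructor defaults shown above; the occupancy and grid sizes are derived inside `__init__` from `bnp.*_occupancy()` and `cta_launch_margin`, so callers only choose `max_cta_per_sm`, `cta_launch_margin`, and the new `multi_stream` flag:

```python
# bn_group=1 keeps the single-GPU path (no IPC buffers are set up);
# set multi_stream=True only when the module is driven from several CUDA streams.
bn = BatchNorm2d_NHWC(num_features=64, fuse_relu=True, bn_group=1,
                      max_cta_per_sm=2, cta_launch_margin=12,
                      multi_stream=False)
```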