Unverified Commit 8a1ed9e8 authored by lly-zero-one, committed by GitHub

Optimize the sync batchnorm by batching the communication (#980)

In this PR, we optimize the performance of SyncBatchNorm and fix one potential issue in the welford_parallel kernel implementation.

- For the performance improvement, we batch the mean/var/count all_gather communication into a single call in the forward path (see the sketch below).
- We likewise batch the two all_reduce calls into one in the backward path.
- We add a contiguous() call on the inputs of the welford_parallel kernel, since splitting the batched buffer produces non-contiguous views.

If there is a standard perf benchmark, I would be happy to run it.
parent a109f856
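The forward-path packing is easiest to see in isolation. Below is a minimal sketch of the idea, not the PR's code: the all_gather is simulated locally so it runs on CPU without a process group, and `world_size`, `num_channels`, and `count` are illustrative values.

```python
import torch

world_size, num_channels, count = 4, 8, 1024  # illustrative sizes

def pack(mean, var_biased, count, dtype=torch.float32, device="cpu"):
    # One flat buffer per rank: [mean (C) | biased var (C) | count (1)]
    count_t = torch.empty(1, dtype=dtype, device=device).fill_(count)
    return torch.cat([mean.view(-1), var_biased.view(-1), count_t], dim=0)

# Stand-in for the per-rank statistics each GPU would compute; in the real
# code a single torch.distributed.all_gather fills combined_list instead.
combined_list = [pack(torch.randn(num_channels), torch.rand(num_channels), count)
                 for _ in range(world_size)]

combined = torch.stack(combined_list, dim=0)          # (world_size, 2*C + 1)
mean_all, var_all, count_all = torch.split(combined, num_channels, dim=1)
count_all = count_all.view(-1)                        # one element count per rank
assert mean_all.shape == (world_size, num_channels)
assert count_all.tolist() == [float(count)] * world_size
```

One gather of a (2*C + 1)-element buffer replaces three separate collectives (mean, var, count), which helps because each collective pays a fixed launch/latency cost regardless of payload size.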
@@ -21,28 +21,26 @@ class SyncBatchnormFunction(Function):
         if channel_last:
             count = int(input.numel()/input.size(-1))
             mean, var_biased = syncbn.welford_mean_var_c_last(input)
+            num_channels = input.size(-1)
         else:
             count = int(input.numel()/input.size(1))
             mean, var_biased = syncbn.welford_mean_var(input)
+            num_channels = input.size(1)
 
         if torch.distributed.is_initialized():
             if not process_group:
                 process_group = torch.distributed.group.WORLD
             device = mean.device
             world_size = torch.distributed.get_world_size(process_group)
-            mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
-            var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
-            count_all = torch.cuda.IntTensor(world_size, device=device)
-            mean_l = [mean_all.narrow(0, i, 1).view(-1) for i in range(world_size)]
-            var_l = [var_all.narrow(0, i, 1).view(-1) for i in range(world_size)]
-            count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
-            torch.distributed.all_gather(mean_l, mean.view(-1), process_group)
-            torch.distributed.all_gather(var_l, var_biased.view(-1), process_group)
-            torch.distributed.all_gather(
-                count_l,
-                torch.cuda.IntTensor([count], device=device),
-                process_group)
-            mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count_all, eps)
+
+            count_t = torch.empty(1, dtype=mean.dtype, device=mean.device).fill_(count)
+            combined = torch.cat([mean.view(-1), var_biased.view(-1), count_t], dim=0)
+            combined_list = [torch.empty_like(combined) for k in range(world_size)]
+            torch.distributed.all_gather(combined_list, combined, process_group)
+            combined = torch.stack(combined_list, dim=0)
+            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
+            count_all = count_all.view(-1)
+            mean, var, inv_std = syncbn.welford_parallel(mean_all, invstd_all, count_all.to(torch.int32), eps)
         else:
             device = mean.device
             count_all = torch.cuda.IntTensor([count], device=device)
@@ -60,7 +58,7 @@ class SyncBatchnormFunction(Function):
             mean = running_mean.data
             inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
 
-        ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all)
+        ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all.to(torch.int32))
         ctx.process_group = process_group
         ctx.channel_last = channel_last
         ctx.world_size = world_size
@@ -101,10 +99,12 @@ class SyncBatchnormFunction(Function):
         if ctx.needs_input_grad[0]:
 
             if torch.distributed.is_initialized():
+                num_channels = sum_dy.shape[0]
+                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                 torch.distributed.all_reduce(
-                    sum_dy, ReduceOp.SUM, process_group)
-                torch.distributed.all_reduce(
-                    sum_dy_xmu, ReduceOp.SUM, process_group)
+                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
+                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
 
             if channel_last:
                 grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
             else:
...
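The backward hunk above applies the same packing idea to the gradient reductions. A minimal sketch, with the all_reduce replaced by an explicit sum over simulated ranks (shapes are illustrative, not from the patch):

```python
import torch

world_size, num_channels = 4, 8  # illustrative sizes
# (sum_dy, sum_dy_xmu) as each rank would compute them locally
per_rank = [(torch.randn(num_channels), torch.randn(num_channels))
            for _ in range(world_size)]

# Each rank concatenates its two gradient buffers, so one all_reduce
# replaces two round trips over the interconnect.
packed = [torch.cat([dy, dy_xmu], dim=0) for dy, dy_xmu in per_rank]
combined = torch.stack(packed).sum(dim=0)   # stand-in for all_reduce(SUM)
sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

# Result matches reducing the two tensors separately.
assert torch.allclose(sum_dy, sum(dy for dy, _ in per_rank))
assert torch.allclose(sum_dy_xmu, sum(dy_xmu for _, dy_xmu in per_rank))
```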
@@ -1155,6 +1155,10 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
   at::Tensor inv_std = at::empty_like(out_var);
   at::Tensor out_mean = at::empty_like(out_var);
 
+  at::Tensor mean_feature_nodes_ = mean_feature_nodes.contiguous();
+  at::Tensor var_biased_ = var_biased.contiguous();
+  at::Tensor numel_ = numel.contiguous();
+
   // TODO(jie): tile this for memory coalescing!
   const int block = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE);
   const int grid = std::max<int>(1, feature_size / block);
@@ -1165,9 +1169,9 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
   using namespace at;
   DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel",
       welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
-          mean_feature_nodes.DATA_PTR<scalar_t_0>(),
-          var_biased.DATA_PTR<scalar_t_0>(),
-          numel.DATA_PTR<int>(),
+          mean_feature_nodes_.DATA_PTR<scalar_t_0>(),
+          var_biased_.DATA_PTR<scalar_t_0>(),
+          numel_.DATA_PTR<int>(),
           out_mean.DATA_PTR<scalar_t_0>(),
           out_var.DATA_PTR<scalar_t_0>(),
           inv_std.DATA_PTR<scalar_t_0>(),
...
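The .contiguous() calls above are the "potential issue" fix: the tensors handed to welford_parallel now come from torch.split, which returns strided views, while the kernel walks raw DATA_PTR memory as if it were dense. A small demonstration of why that matters (sizes illustrative):

```python
import torch

# A (world_size, 2*C + 1) gather result, as in the forward path above.
combined = torch.arange(12.0).view(4, 3)
mean_all, count_all = torch.split(combined, [2, 1], dim=1)

print(mean_all.is_contiguous())    # False: a strided view into `combined`
dense = mean_all.contiguous()      # materializes a dense copy
print(dense.is_contiguous())       # True: safe to walk via a raw pointer
```

Without the copy, a raw-pointer kernel would read the underlying storage linearly and pick up elements belonging to the neighboring split outputs.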