Unverified Commit 8a1ed9e8 authored by lly-zero-one, committed by GitHub

Optimize the sync batchnorm by batching the communication (#980)

In this PR, we optimize the performance of SyncBatchNorm and also fix one potential issue in the welford_parallel kernel implementation.

For the performance improvement, we batch the mean/var/count all_gather communication together and send it once in the forward path (see the sketch below).
We also batch the all_reduce calls in the backward path.
We add contiguous() calls on the inputs of the welford_parallel kernel.

If there is a standard perf benchmark, I would be happy to run it.
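
A rough sketch of the batching pattern described above (helper names such as `batched_allgather_stats` are illustrative and not part of apex; the actual changes live inside SyncBatchnormFunction, see the diff below):

```python
import torch
import torch.distributed as dist

def batched_allgather_stats(mean, var_biased, count, process_group):
    """Forward path: gather per-rank (mean, biased var, count) with ONE all_gather.

    Packing the three pieces into a single flat tensor replaces three
    separate all_gather calls with one collective.
    """
    world_size = dist.get_world_size(process_group)
    num_channels = mean.numel()
    count_t = torch.full((1,), count, dtype=mean.dtype, device=mean.device)
    combined = torch.cat([mean.view(-1), var_biased.view(-1), count_t])

    gathered = [torch.empty_like(combined) for _ in range(world_size)]
    dist.all_gather(gathered, combined, group=process_group)

    stacked = torch.stack(gathered, dim=0)                      # (world_size, 2*C + 1)
    mean_all, var_all, count_all = torch.split(stacked, num_channels, dim=1)
    return mean_all, var_all, count_all.view(-1)

def batched_allreduce_grads(sum_dy, sum_dy_xmu, process_group):
    """Backward path: sum both per-channel reductions with ONE all_reduce."""
    num_channels = sum_dy.shape[0]
    combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
    dist.all_reduce(combined, op=dist.ReduceOp.SUM, group=process_group)
    return torch.split(combined, num_channels)
```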
parent a109f856
@@ -21,33 +21,31 @@ class SyncBatchnormFunction(Function):
         if channel_last:
             count = int(input.numel()/input.size(-1))
             mean, var_biased = syncbn.welford_mean_var_c_last(input)
+            num_channels = input.size(-1)
         else:
             count = int(input.numel()/input.size(1))
             mean, var_biased = syncbn.welford_mean_var(input)
+            num_channels = input.size(1)
         if torch.distributed.is_initialized():
             if not process_group:
                 process_group = torch.distributed.group.WORLD
             device = mean.device
             world_size = torch.distributed.get_world_size(process_group)
-            mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
-            var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
-            count_all = torch.cuda.IntTensor(world_size, device=device)
-            mean_l = [mean_all.narrow(0, i, 1).view(-1) for i in range(world_size)]
-            var_l = [var_all.narrow(0, i, 1).view(-1) for i in range(world_size)]
-            count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
-            torch.distributed.all_gather(mean_l, mean.view(-1), process_group)
-            torch.distributed.all_gather(var_l, var_biased.view(-1), process_group)
-            torch.distributed.all_gather(
-                count_l,
-                torch.cuda.IntTensor([count], device=device),
-                process_group)
-            mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count_all, eps)
+            count_t = torch.empty(1, dtype=mean.dtype, device=mean.device).fill_(count)
+            combined = torch.cat([mean.view(-1), var_biased.view(-1), count_t], dim=0)
+            combined_list = [torch.empty_like(combined) for k in range(world_size)]
+            torch.distributed.all_gather(combined_list, combined, process_group)
+            combined = torch.stack(combined_list, dim=0)
+            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
+            count_all = count_all.view(-1)
+            mean, var, inv_std = syncbn.welford_parallel(mean_all, invstd_all, count_all.to(torch.int32), eps)
         else:
             device = mean.device
             count_all = torch.cuda.IntTensor([count], device=device)
             inv_std = 1.0 / torch.sqrt(var_biased + eps)
-            var = var_biased * (count) / (count-1)
+            var = var_biased * (count) / (count-1)
         if count == 1 and world_size < 2:
             raise ValueError('Expected more than 1 value per channel when training, got input size{}'.format(input.size()))
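
For reference, a plain-PyTorch sketch of what the fused `syncbn.welford_parallel` call above is expected to compute from the gathered per-rank statistics, assuming the standard parallel mean/variance combination (Chan et al.); this is an illustration, not the CUDA kernel itself:

```python
import torch

def welford_parallel_reference(mean_all, var_biased_all, count_all, eps):
    """Combine per-rank (mean, biased var, count) into global statistics.

    mean_all, var_biased_all: (world_size, C); count_all: (world_size,).
    """
    counts = count_all.to(mean_all.dtype).view(-1, 1)    # (world_size, 1)
    total = counts.sum()
    mean = (mean_all * counts).sum(0) / total             # global per-channel mean
    # Total M2 = sum of per-rank M2 plus the between-rank correction term.
    m2 = (var_biased_all * counts).sum(0) + (counts * (mean_all - mean) ** 2).sum(0)
    var_biased = m2 / total
    var = m2 / (total - 1)                                # unbiased variance
    inv_std = torch.rsqrt(var_biased + eps)
    return mean, var, inv_std
```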
@@ -60,7 +58,7 @@ class SyncBatchnormFunction(Function):
             mean = running_mean.data
             inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
-        ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all)
+        ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all.to(torch.int32))
         ctx.process_group = process_group
         ctx.channel_last = channel_last
         ctx.world_size = world_size
@@ -101,10 +99,12 @@ class SyncBatchnormFunction(Function):
         if ctx.needs_input_grad[0]:
             if torch.distributed.is_initialized():
+                num_channels = sum_dy.shape[0]
+                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                 torch.distributed.all_reduce(
-                    sum_dy, ReduceOp.SUM, process_group)
-                torch.distributed.all_reduce(
-                    sum_dy_xmu, ReduceOp.SUM, process_group)
+                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
+                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
             if channel_last:
                 grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
             else:
......
@@ -1155,6 +1155,10 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
   at::Tensor inv_std = at::empty_like(out_var);
   at::Tensor out_mean = at::empty_like(out_var);
+  at::Tensor mean_feature_nodes_ = mean_feature_nodes.contiguous();
+  at::Tensor var_biased_ = var_biased.contiguous();
+  at::Tensor numel_ = numel.contiguous();
   // TODO(jie): tile this for memory coalescing!
   const int block = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE);
   const int grid = std::max<int>(1, feature_size / block);
@@ -1165,9 +1169,9 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
   using namespace at;
   DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel",
       welford_kernel_parallel<scalar_t_0><<<grid, block, 0, stream>>>(
-          mean_feature_nodes.DATA_PTR<scalar_t_0>(),
-          var_biased.DATA_PTR<scalar_t_0>(),
-          numel.DATA_PTR<int>(),
+          mean_feature_nodes_.DATA_PTR<scalar_t_0>(),
+          var_biased_.DATA_PTR<scalar_t_0>(),
+          numel_.DATA_PTR<int>(),
           out_mean.DATA_PTR<scalar_t_0>(),
           out_var.DATA_PTR<scalar_t_0>(),
           inv_std.DATA_PTR<scalar_t_0>(),
......
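
A note on the contiguous() calls added in the CUDA wrapper above: with the new batched path, `torch.split` along dim=1 of the gathered tensor returns strided, non-contiguous views, while the kernel launch reads raw `DATA_PTR` pointers that assume a packed layout; this appears to be the potential issue mentioned in the description. A minimal illustration (shapes are arbitrary):

```python
import torch

world_size, num_channels = 4, 8
# Shape mirrors the gathered tensor: per-rank [mean | var | count] rows.
combined = torch.randn(world_size, 2 * num_channels + 1)
mean_all, var_all, count_all = torch.split(combined, num_channels, dim=1)

print(mean_all.is_contiguous())                # False: strided view into `combined`
print(mean_all.contiguous().is_contiguous())   # True: packed copy, safe for raw-pointer kernels
```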