Commit 5dad4c21 authored by jjsjann123, committed by mcarilli

[syncBN] (#90)

Support a user-specified process group within which the mini-batch statistics are synchronized.
parent bc62f325
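For reference, a minimal usage sketch of the new argument. The sub-group below is purely illustrative and assumes torch.distributed has already been initialized (e.g. via torch.distributed.init_process_group); when process_group is left as None, the default (world) group is used, as before.

import torch
import apex

# Hypothetical sub-group containing ranks 0 and 1 only; any group created
# with torch.distributed.new_group works here.
group = torch.distributed.new_group(ranks=[0, 1])

# Batch-norm statistics are now synchronized only across the ranks in `group`.
sbn = apex.parallel.SyncBatchNorm(100, process_group=group).cuda()
out = sbn(torch.randn(10, 100, 14, 14).cuda())

# The group can also be set after construction via the helper added in this
# commit:
# sbn._specify_process_group(group)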
@@ -35,6 +35,9 @@ class SyncBatchNorm(_BatchNorm):
             module tracks the running mean and variance, and when set to ``False``,
             this module does not track such statistics and always uses batch
             statistics in both training and eval modes. Default: ``True``
+        process_group: pass in a process group within which the stats of the
+            mini-batch is being synchronized. ``None`` for using default process
+            group

     Examples::
         >>> sbn = apex.parallel.SyncBatchNorm(100).cuda()
@@ -44,8 +47,12 @@ class SyncBatchNorm(_BatchNorm):
         >>> out = sbn(inp)
     """

-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
         super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+        self.process_group = process_group
+
+    def _specify_process_group(self, process_group):
+        self.process_group = process_group

     def forward(self, input):
         if not self.training and self.track_running_stats:
@@ -53,4 +60,4 @@ class SyncBatchNorm(_BatchNorm):
             return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
         else:
             self.num_batches_tracked += 1
-            return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.track_running_stats, self.momentum)
+            return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.track_running_stats, self.momentum, self.process_group)
@@ -6,21 +6,26 @@ import syncbn
 class SyncBatchnormFunction(Function):

     @staticmethod
-    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0):
+    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None):
         torch.cuda.nvtx.range_push("sync_BN_fw")
         input = input.contiguous()
         world_size = 0

         if track_running_stats:
             mean, var, var_biased = syncbn.welford_mean_var(input)

             if torch.distributed.is_initialized():
-                world_size = torch.distributed.get_world_size()
+                if process_group:
+                    world_size = torch.distributed.get_world_size(process_group)
+                else:
+                    process_group = torch.distributed.get_default_group()
+                    world_size = torch.distributed.get_world_size()
                 mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
                 var_all = torch.empty(world_size, var.size(0), dtype=var.dtype, device=var.device)
                 mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
                 var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
-                torch.distributed.all_gather(mean_l, mean)
-                torch.distributed.all_gather(var_l, var_biased)
+                torch.distributed.all_gather(mean_l, mean, process_group)
+                torch.distributed.all_gather(var_l, var_biased, process_group)
                 mean, var, var_biased = syncbn.welford_parallel(mean_all.transpose(1,0).contiguous(), var_all.transpose(1,0).contiguous(), int(input.numel()/input.size(1)))
                 # TODO(Jie): should do fp32 math instead!
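A standalone sketch of the gather pattern used in this hunk: each rank's per-channel statistic is all-gathered within the chosen group before being combined. The helper name and shapes are illustrative, not part of the commit; it assumes torch.distributed is initialized with a backend that supports all_gather on the tensors involved (e.g. NCCL for CUDA tensors).

import torch
import torch.distributed as dist

def gather_stat(stat, process_group=None):
    # Gather a 1-D per-rank statistic (e.g. the per-channel mean) from every
    # rank in `process_group`; None means the default (world) group.
    world_size = dist.get_world_size(process_group)
    gathered = [torch.empty_like(stat) for _ in range(world_size)]
    dist.all_gather(gathered, stat, process_group)
    # Stack into a (world_size, C) tensor, analogous to mean_all/var_all above.
    return torch.stack(gathered, dim=0)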
@@ -34,6 +39,8 @@ class SyncBatchnormFunction(Function):
         ctx.save_for_backward(input, weight, mean, var_biased)
         ctx.eps = eps
+        ctx.process_group = process_group
+        ctx.world_size = world_size

         out = syncbn.batchnorm_forward(input, mean, var_biased, weight, bias, eps)
@@ -49,6 +56,8 @@ class SyncBatchnormFunction(Function):
         # var = 1./N*np.sum((h-mu)**2, axis = 0)
         saved_input, weight, running_mean, running_variance = ctx.saved_tensors
         eps = ctx.eps
+        process_group = ctx.process_group
+        world_size = ctx.world_size
         grad_input = grad_weight = grad_bias = None

         # TODO(jie): why do I have to clone here? life time of grad_output?
@@ -59,11 +68,11 @@ class SyncBatchnormFunction(Function):
             if torch.distributed.is_initialized():
                 torch.distributed.all_reduce(
-                    mean_dy, op=torch.distributed.reduce_op.SUM)
-                mean_dy = mean_dy / torch.distributed.get_world_size()
+                    mean_dy, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy = mean_dy / world_size
                 torch.distributed.all_reduce(
-                    mean_dy_xmu, op=torch.distributed.reduce_op.SUM)
-                mean_dy_xmu = mean_dy_xmu / torch.distributed.get_world_size()
+                    mean_dy_xmu, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy_xmu = mean_dy_xmu / world_size
             grad_input = syncbn.batchnorm_backward(grad_output, saved_input, running_mean, running_variance, weight, mean_dy, mean_dy_xmu, eps)

         if weight is None or not ctx.needs_input_grad[1]:
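The backward changes follow the same pattern: sum-reduce within the (possibly non-default) group, then divide by that group's world size rather than the global one. A minimal standalone sketch, with an illustrative helper name (ReduceOp is the current spelling of the deprecated reduce_op alias used in the diff):

import torch
import torch.distributed as dist

def group_mean_(t, process_group=None, world_size=None):
    # In-place average of tensor `t` across the ranks of `process_group`.
    if world_size is None:
        world_size = dist.get_world_size(process_group)
    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=process_group)
    t.div_(world_size)
    return t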
@@ -73,4 +82,4 @@ class SyncBatchnormFunction(Function):
             grad_bias = None

         torch.cuda.nvtx.range_pop()
-        return grad_input, grad_weight, grad_bias, None, None, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
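The extra trailing None corresponds to the new process_group argument: a custom autograd Function's backward must return one value per forward input, with None for arguments that need no gradient.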