Unverified commit 3f9b5c98, authored by mcarilli and committed by GitHub

Merge pull request #94 from NVIDIA/sbn_util

Adding process group in convert_syncbn_model
parents 920da6da 6d3c75e5
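
The change threads an optional process_group argument from convert_syncbn_model down into SyncBatchNorm, so batch-norm statistics are reduced over a chosen group of ranks instead of always over the whole world. A minimal usage sketch of the new argument (MyModel is a placeholder, and the rank list is illustrative; any handle returned by torch.distributed.new_group works):

import torch
import apex.parallel

# Set up the default process group first (assumes the usual env:// rendezvous
# variables MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are set).
torch.distributed.init_process_group(backend="nccl", init_method="env://")

# Illustrative: sync batch-norm statistics only across ranks 0-3
# (e.g. the GPUs of one node). Every rank must call new_group.
bn_group = torch.distributed.new_group(ranks=[0, 1, 2, 3])

model = MyModel().cuda()  # MyModel is a placeholder for your network

# Passing process_group=None (the default) keeps the previous behavior of
# reducing across the default (world) group.
model = apex.parallel.convert_syncbn_model(model, process_group=bn_group)
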
@@ -8,7 +8,7 @@ except ImportError:
    print("Warning: apex was installed without --cuda_ext. Fused syncbn kernels will be unavailable. Python fallbacks will be used instead.")
    from .sync_batchnorm import SyncBatchNorm

-def convert_syncbn_model(module):
+def convert_syncbn_model(module, process_group=None):
    '''
    Recursively traverse module and its children to replace all
    `torch.nn.modules.batchnorm._BatchNorm` with `apex.parallel.SyncBatchNorm`
@@ -27,7 +27,7 @@ def convert_syncbn_model(module):
    '''
    mod = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats)
+        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group)
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        if module.affine:
@@ -44,8 +44,12 @@ class SyncBatchNorm(_BatchNorm):
        >>> out = sbn(inp)
    """

-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+        self.process_group = process_group
+
+    def _specify_process_group(self, process_group):
+        self.process_group = process_group

    def forward(self, input):
        torch.cuda.nvtx.range_push("sync_bn_fw_with_mean_var")
@@ -56,6 +60,13 @@ class SyncBatchNorm(_BatchNorm):
            torch.cuda.nvtx.range_pop()
            return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
        else:
+            process_group = self.process_group
+            world_size = 0
+            if self.process_group:
+                world_size = torch.distributed.get_world_size(process_group)
+            else:
+                process_group = torch.distributed.get_default_group()
+                world_size = torch.distributed.get_world_size()
            self.num_batches_tracked += 1
            with torch.no_grad():
                channel_first_input = input.transpose(0, 1).contiguous()
@@ -69,12 +80,12 @@ class SyncBatchNorm(_BatchNorm):
                    squashed_input_tensor_view, 2).mean(1)
                if torch.distributed.is_initialized():
                    torch.distributed.all_reduce(
-                        local_mean, op=torch.distributed.reduce_op.SUM)
-                    mean = local_mean / torch.distributed.get_world_size()
+                        local_mean, torch.distributed.reduce_op.SUM, process_group)
+                    mean = local_mean / world_size
                    torch.distributed.all_reduce(
-                        local_sqr_mean, op=torch.distributed.reduce_op.SUM)
-                    sqr_mean = local_sqr_mean / torch.distributed.get_world_size()
-                    m = local_m * torch.distributed.get_world_size()
+                        local_sqr_mean, torch.distributed.reduce_op.SUM, process_group)
+                    sqr_mean = local_sqr_mean / world_size
+                    m = local_m * world_size
                else:
                    m = local_m
                    mean = local_mean
@@ -94,4 +105,4 @@ class SyncBatchNorm(_BatchNorm):
                        (m-1) * self.momentum * var + \
                        (1 - self.momentum) * self.running_var
        torch.cuda.nvtx.range_pop()
-        return SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps)
+        return SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps, process_group, world_size)
@@ -5,7 +5,7 @@ from torch.autograd.function import Function
class SyncBatchnormFunction(Function):

    @staticmethod
-    def forward(ctx, input, weight, bias, running_mean, running_variance, eps):
+    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, process_group, world_size):
        torch.cuda.nvtx.range_push("sync_BN_fw")
        # transpose it to channel last to support broadcasting for input with different rank
        c_last_input = input.transpose(1, -1).contiguous().clone()
@@ -13,6 +13,8 @@ class SyncBatchnormFunction(Function):
        ctx.save_for_backward(c_last_input, weight, bias,
                              running_mean, running_variance)
        ctx.eps = eps
+        ctx.process_group = process_group
+        ctx.world_size = world_size

        c_last_input = (c_last_input - running_mean) / \
            torch.sqrt(running_variance + eps)
@@ -34,6 +36,8 @@ class SyncBatchnormFunction(Function):
        c_last_input, weight, bias, running_mean, running_variance = ctx.saved_tensors
        eps = ctx.eps
+        process_group = ctx.process_group
+        world_size = ctx.world_size
        grad_input = grad_weight = grad_bias = None
        num_features = running_mean.size()[0]
@@ -53,11 +57,11 @@ class SyncBatchnormFunction(Function):
                                          running_mean)).view(-1, num_features).mean(0)
            if torch.distributed.is_initialized():
                torch.distributed.all_reduce(
-                    mean_dy, op=torch.distributed.reduce_op.SUM)
-                mean_dy = mean_dy / torch.distributed.get_world_size()
+                    mean_dy, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy = mean_dy / world_size
                torch.distributed.all_reduce(
-                    mean_dy_xmu, op=torch.distributed.reduce_op.SUM)
-                mean_dy_xmu = mean_dy_xmu / torch.distributed.get_world_size()
+                    mean_dy_xmu, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy_xmu = mean_dy_xmu / world_size
            c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / (
                running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps)
        if weight is not None:
@@ -78,4 +82,4 @@ class SyncBatchnormFunction(Function):
            grad_bias = c_grad.sum(0)
        torch.cuda.nvtx.range_pop()
-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None
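
The diff also adds a _specify_process_group helper on SyncBatchNorm itself, so an already-constructed layer can be bound to a group after the fact. A small sketch, assuming the default process group is initialized as in the example above (the even-rank group is illustrative):

import torch
from apex.parallel import SyncBatchNorm

# A standalone synced batch-norm layer; with process_group=None it reduces
# its per-channel mean and squared mean across the whole world.
sbn = SyncBatchNorm(100).cuda()

# Illustrative: later restrict synchronization to the even ranks.
even_group = torch.distributed.new_group(ranks=[0, 2, 4, 6])
sbn._specify_process_group(even_group)

# In training mode, forward() now all-reduces the statistics only within
# that group and divides by the group's world size instead of
# torch.distributed.get_world_size().
inp = torch.randn(10, 100, 14, 14).cuda()
out = sbn(inp)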