Unverified Commit b11c5660 authored by Kai Chen, committed by GitHub

fix syncbn when dist is unavailable (#388)

parent b2b42cbd
The existing SyncBatchNorm implementation is kept unchanged apart from being re-indented one level under a new "if dist.is_available():" guard; an "else" branch adds a stub class that raises NotImplementedError when the distributed package is missing. After the change, the hunk reads:

@@ -109,88 +109,99 @@ class SyncBatchNormFunction(Function):
            None, None, None, None


if dist.is_available():

    class SyncBatchNorm(Module):

        def __init__(self,
                     num_features,
                     eps=1e-5,
                     momentum=0.1,
                     affine=True,
                     track_running_stats=True,
                     group=dist.group.WORLD):
            super(SyncBatchNorm, self).__init__()
            self.num_features = num_features
            self.eps = eps
            self.momentum = momentum
            self.affine = affine
            self.track_running_stats = track_running_stats
            self.group = group
            self.group_size = dist.get_world_size(group)
            if self.affine:
                self.weight = Parameter(torch.Tensor(num_features))
                self.bias = Parameter(torch.Tensor(num_features))
            else:
                self.register_parameter('weight', None)
                self.register_parameter('bias', None)
            if self.track_running_stats:
                self.register_buffer('running_mean', torch.zeros(num_features))
                self.register_buffer('running_var', torch.ones(num_features))
                self.register_buffer('num_batches_tracked',
                                     torch.tensor(0, dtype=torch.long))
            else:
                self.register_buffer('running_mean', None)
                self.register_buffer('running_var', None)
                self.register_buffer('num_batches_tracked', None)
            self.reset_parameters()

        def reset_running_stats(self):
            if self.track_running_stats:
                self.running_mean.zero_()
                self.running_var.fill_(1)
                self.num_batches_tracked.zero_()

        def reset_parameters(self):
            self.reset_running_stats()
            if self.affine:
                self.weight.data.uniform_()  # pytorch use ones_()
                self.bias.data.zero_()

        def forward(self, input):
            if input.dim() < 2:
                raise ValueError(
                    f'expected at least 2D input, got {input.dim()}D input')
            if self.momentum is None:
                exponential_average_factor = 0.0
            else:
                exponential_average_factor = self.momentum

            if self.training and self.track_running_stats:
                if self.num_batches_tracked is not None:
                    self.num_batches_tracked += 1
                    if self.momentum is None:  # use cumulative moving average
                        exponential_average_factor = 1.0 / float(
                            self.num_batches_tracked)
                    else:  # use exponential moving average
                        exponential_average_factor = self.momentum

            if self.training or not self.track_running_stats:
                return SyncBatchNormFunction.apply(input, self.running_mean,
                                                   self.running_var,
                                                   self.weight, self.bias,
                                                   exponential_average_factor,
                                                   self.eps, self.group,
                                                   self.group_size)
            else:
                return F.batch_norm(input, self.running_mean, self.running_var,
                                    self.weight, self.bias, False,
                                    exponential_average_factor, self.eps)

        def __repr__(self):
            s = self.__class__.__name__
            s += f'({self.num_features}, '
            s += f'eps={self.eps}, '
            s += f'momentum={self.momentum}, '
            s += f'affine={self.affine}, '
            s += f'track_running_stats={self.track_running_stats}, '
            s += f'group_size={self.group_size})'
            return s

else:

    class SyncBatchNorm(Module):

        def __init__(self, *args, **kwargs):
            raise NotImplementedError(
                'SyncBatchNorm is not supported in this OS since the '
                'distributed package is not available')
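A minimal usage sketch of how calling code can rely on this guard (not part of the commit; the import path "sync_bn" is an assumption for illustration). Note that on distributed builds, torch.distributed must also be initialized with a process group before SyncBatchNorm is constructed, since __init__ calls dist.get_world_size on the default group.

import torch
import torch.distributed as dist

from sync_bn import SyncBatchNorm  # assumed import path for this module

if not dist.is_available():
    # Build without the distributed package: the module still imports,
    # but the stub class raises NotImplementedError on construction.
    try:
        SyncBatchNorm(64)
    except NotImplementedError as err:
        print(err)
elif dist.is_initialized():
    # Distributed build with an initialized process group: the real
    # synchronized BN layer is used like a regular BatchNorm module.
    bn = SyncBatchNorm(64).cuda()
    x = torch.randn(2, 64, 8, 8, device='cuda')
    y = bn(x)  # statistics are synchronized across the default group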