"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7ab7c121733873a850cee368319f3f6fa558d12f"
Unverified commit 7cfc839e, authored by Jerry Jiarui XU, committed by GitHub

Add MMSyncBN in registry (#420)

* Add MMSyncBN in registry

* skip mmsyncbn test
parent 50f69e70
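With this change, the `SyncBatchNorm` op from `mmcv.ops` becomes addressable through the `NORM_LAYERS` registry under the name `MMSyncBN`, so it can be selected from a config dict like any other norm layer. A minimal usage sketch, not part of the commit: it assumes an mmcv build with the CUDA sync-BN extension and an already initialized default process group, since the constructor calls `dist.get_world_size()`.

# Sketch only: look up and build the newly registered layer via the
# NORM_LAYERS registry. Assumes torch.distributed has been initialized
# (e.g. with init_process_group) before the layer is constructed.
import torch.distributed as dist
from mmcv.cnn import NORM_LAYERS, build_norm_layer

sync_bn_cls = NORM_LAYERS.get('MMSyncBN')   # the class registered in the diff below

if dist.is_available() and dist.is_initialized():
    name, layer = build_norm_layer(dict(type='MMSyncBN'), 64)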
@@ -6,6 +6,7 @@ from torch.autograd.function import once_differentiable
 from torch.nn.modules.module import Module
 from torch.nn.parameter import Parameter
 
+from mmcv.cnn import NORM_LAYERS
 from ..utils import ext_loader
 
 ext_module = ext_loader.load_ext('_ext', [
@@ -109,99 +110,89 @@ class SyncBatchNormFunction(Function):
                 None, None, None, None
 
 
-if dist.is_available():
-
-    class SyncBatchNorm(Module):
+@NORM_LAYERS.register_module(name='MMSyncBN')
+class SyncBatchNorm(Module):
 
     def __init__(self,
                  num_features,
                  eps=1e-5,
                  momentum=0.1,
                  affine=True,
                  track_running_stats=True,
-                 group=dist.group.WORLD):
+                 group=None):
         super(SyncBatchNorm, self).__init__()
         self.num_features = num_features
         self.eps = eps
         self.momentum = momentum
         self.affine = affine
         self.track_running_stats = track_running_stats
-        self.group = group
+        self.group = dist.group.WORLD if group is None else group
         self.group_size = dist.get_world_size(group)
         if self.affine:
             self.weight = Parameter(torch.Tensor(num_features))
             self.bias = Parameter(torch.Tensor(num_features))
         else:
             self.register_parameter('weight', None)
             self.register_parameter('bias', None)
         if self.track_running_stats:
             self.register_buffer('running_mean', torch.zeros(num_features))
             self.register_buffer('running_var', torch.ones(num_features))
             self.register_buffer('num_batches_tracked',
                                  torch.tensor(0, dtype=torch.long))
         else:
             self.register_buffer('running_mean', None)
             self.register_buffer('running_var', None)
             self.register_buffer('num_batches_tracked', None)
         self.reset_parameters()
 
     def reset_running_stats(self):
         if self.track_running_stats:
             self.running_mean.zero_()
             self.running_var.fill_(1)
             self.num_batches_tracked.zero_()
 
     def reset_parameters(self):
         self.reset_running_stats()
         if self.affine:
             self.weight.data.uniform_()  # pytorch use ones_()
             self.bias.data.zero_()
 
     def forward(self, input):
         if input.dim() < 2:
             raise ValueError(
                 f'expected at least 2D input, got {input.dim()}D input')
         if self.momentum is None:
             exponential_average_factor = 0.0
         else:
             exponential_average_factor = self.momentum
 
         if self.training and self.track_running_stats:
             if self.num_batches_tracked is not None:
                 self.num_batches_tracked += 1
                 if self.momentum is None:  # use cumulative moving average
                     exponential_average_factor = 1.0 / float(
                         self.num_batches_tracked)
                 else:  # use exponential moving average
                     exponential_average_factor = self.momentum
 
         if self.training or not self.track_running_stats:
             return SyncBatchNormFunction.apply(input, self.running_mean,
                                                self.running_var, self.weight,
                                                self.bias,
                                                exponential_average_factor,
                                                self.eps, self.group,
                                                self.group_size)
         else:
             return F.batch_norm(input, self.running_mean, self.running_var,
                                 self.weight, self.bias, False,
                                 exponential_average_factor, self.eps)
 
     def __repr__(self):
         s = self.__class__.__name__
         s += f'({self.num_features}, '
         s += f'eps={self.eps}, '
         s += f'momentum={self.momentum}, '
         s += f'affine={self.affine}, '
         s += f'track_running_stats={self.track_running_stats}, '
         s += f'group_size={self.group_size})'
         return s
-
-else:
-
-    class SyncBatchNorm(Module):
-
-        def __init__(self, *args, **kwargs):
-            raise NotImplementedError(
-                'SyncBatchNorm is not supported in this OS since the '
-                'distributed package is not available')
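Besides the registration, the constructor now defaults to `group=None` and resolves the world group inside `__init__` instead of hard-coding `dist.group.WORLD` in the signature. Passing an explicit process group limits statistics synchronization to that group's ranks. A hedged sketch of that usage, not from the commit; the rank list is illustrative and it assumes a multi-process job where `init_process_group` has already run:

# Sketch only: pass an explicit process group so batch statistics are
# synchronized within that group rather than across all ranks.
import torch.distributed as dist
from mmcv.ops import SyncBatchNorm

if dist.is_available() and dist.is_initialized():
    # every rank must call new_group(); only ranks 0-3 are members here
    sub_group = dist.new_group(ranks=[0, 1, 2, 3])
    norm = SyncBatchNorm(64, group=sub_group)

    # group=None (the default) resolves to dist.group.WORLD in __init__
    norm_world = SyncBatchNorm(64)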
@@ -135,6 +135,8 @@ def test_build_norm_layer():
         'IN3d': 'in',
     }
     for type_name, module in NORM_LAYERS.module_dict.items():
+        if type_name == 'MMSyncBN':  # skip MMSyncBN
+            continue
         for postfix in ['_test', 1]:
             cfg = dict(type=type_name)
             if type_name == 'GN':
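The registry test iterates over every entry in `NORM_LAYERS.module_dict` and instantiates it, which would fail for `MMSyncBN` in a plain CPU test environment because its constructor needs an initialized distributed process group; hence the skip. A hedged sketch of a separate test that only runs when such a group exists (the test name and channel count are illustrative, not part of the commit):

# Sketch only: exercise MMSyncBN through build_norm_layer when a
# distributed process group is actually available.
import pytest
import torch.distributed as dist
from mmcv.cnn import build_norm_layer


@pytest.mark.skipif(
    not (dist.is_available() and dist.is_initialized()),
    reason='MMSyncBN requires an initialized process group')
def test_build_mmsyncbn():
    name, layer = build_norm_layer(dict(type='MMSyncBN'), 8)
    assert layer.num_features == 8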