Unverified Commit 16e17159 authored by Shilong Zhang's avatar Shilong Zhang Committed by GitHub
Browse files

Fix `NaiveSyncBatchNorm1d` and `NaiveSyncBatchNorm2d` (#1435)

* add quick install command

* fix SyncBatchNorm

* fix SyncBatchNorm
parent d8425466
...@@ -53,11 +53,27 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d): ...@@ -53,11 +53,27 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
# TODO: make mmcv fp16 utils handle customized norm layers # TODO: make mmcv fp16 utils handle customized norm layers
@force_fp32(out_fp16=True) @force_fp32(out_fp16=True)
def forward(self, input): def forward(self, input):
"""
Args:
input (tensor): Has shape (N, C) or (N, C, L), where N is
the batch size, C is the number of features or
channels, and L is the sequence length
Returns:
tensor: Has shape (N, C) or (N, C, L), has same shape
as input.
"""
assert input.dtype == torch.float32, \ assert input.dtype == torch.float32, \
f'input should be in float32 type, got {input.dtype}' f'input should be in float32 type, got {input.dtype}'
if dist.get_world_size() == 1 or not self.training: using_dist = dist.is_available() and dist.is_initialized()
if (not using_dist) or dist.get_world_size() == 1 \
or not self.training:
return super().forward(input) return super().forward(input)
assert input.shape[0] > 0, 'SyncBN does not support empty inputs' assert input.shape[0] > 0, 'SyncBN does not support empty inputs'
is_two_dim = input.dim() == 2
if is_two_dim:
input = input.unsqueeze(2)
C = input.shape[1] C = input.shape[1]
mean = torch.mean(input, dim=[0, 2]) mean = torch.mean(input, dim=[0, 2])
meansqr = torch.mean(input * input, dim=[0, 2]) meansqr = torch.mean(input * input, dim=[0, 2])
...@@ -76,7 +92,10 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d): ...@@ -76,7 +92,10 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
bias = self.bias - mean * scale bias = self.bias - mean * scale
scale = scale.reshape(1, -1, 1) scale = scale.reshape(1, -1, 1)
bias = bias.reshape(1, -1, 1) bias = bias.reshape(1, -1, 1)
return input * scale + bias output = input * scale + bias
if is_two_dim:
output = output.squeeze(2)
return output
@NORM_LAYERS.register_module('naiveSyncBN2d') @NORM_LAYERS.register_module('naiveSyncBN2d')
...@@ -107,9 +126,19 @@ class NaiveSyncBatchNorm2d(nn.BatchNorm2d): ...@@ -107,9 +126,19 @@ class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
# TODO: make mmcv fp16 utils handle customized norm layers # TODO: make mmcv fp16 utils handle customized norm layers
@force_fp32(out_fp16=True) @force_fp32(out_fp16=True)
def forward(self, input): def forward(self, input):
"""
Args:
            input (tensor): Feature has shape (N, C, H, W).
Returns:
tensor: Has shape (N, C, H, W), same shape as input.
"""
assert input.dtype == torch.float32, \ assert input.dtype == torch.float32, \
f'input should be in float32 type, got {input.dtype}' f'input should be in float32 type, got {input.dtype}'
if dist.get_world_size() == 1 or not self.training: using_dist = dist.is_available() and dist.is_initialized()
if (not using_dist) or \
dist.get_world_size() == 1 or \
not self.training:
return super().forward(input) return super().forward(input)
assert input.shape[0] > 0, 'SyncBN does not support empty inputs' assert input.shape[0] > 0, 'SyncBN does not support empty inputs'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment