Unverified commit 4e101e0b authored by Wenwei Zhang, committed by GitHub

[Feature]: Support empty tensor in MMSyncBN (#1205)

* [Feature]: support empty tensor in MMSyncBN

* refine code

* resolve comments

* clean unnecessary comments

* fix inaccurate statistics when the input tensor is empty

* resolve comments and add docstrings

* update unit tests

* rephrase, ready for merge
parent b6eb3822
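For context, a minimal usage sketch of the feature this commit adds. The import path, the SyncBatchNorm class (registered as 'MMSyncBN') and the stats_mode argument come from the diff below; the distributed setup and tensor shapes are illustrative assumptions only.

# Sketch only: assumes a distributed process group is already initialized
# (one process per GPU), e.g. via torch.distributed.init_process_group().
import torch
from mmcv.ops import SyncBatchNorm

# stats_mode='N' weights the synchronized statistics by each rank's batch
# size, so ranks that happen to receive an empty tensor do not skew them.
bn = SyncBatchNorm(3, stats_mode='N').cuda()
bn.train()

x = torch.rand(4, 3, 2, 3).cuda()  # a rank may also see torch.rand(0, 3, 2, 3)
y = bn(x)                          # an empty input yields an empty output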
@@ -20,7 +20,7 @@ class SyncBatchNormFunction(Function):

    @staticmethod
    def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
                 eps, group, group_size, stats_mode):
        return g.op(
            'mmcv::MMCVSyncBatchNorm',
            input,
@@ -31,41 +31,82 @@ class SyncBatchNormFunction(Function):
            momentum_f=momentum,
            eps_f=eps,
            group_i=group,
            group_size_i=group_size,
            stats_mode=stats_mode)

    @staticmethod
    def forward(self, input, running_mean, running_var, weight, bias, momentum,
                eps, group, group_size, stats_mode):
        self.momentum = momentum
        self.eps = eps
        self.group = group
        self.group_size = group_size
        self.stats_mode = stats_mode

        assert isinstance(
            input, (torch.HalfTensor, torch.FloatTensor,
                    torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
            f'only support Half or Float Tensor, but {input.type()}'
        output = torch.zeros_like(input)
        input3d = input.flatten(start_dim=2)
        output3d = output.view_as(input3d)
        num_channels = input3d.size(1)

        # ensure mean/var/norm/std are initialized as zeros
        # ``torch.empty()`` does not guarantee that
        mean = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)
        var = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)
        norm = torch.zeros_like(
            input3d, dtype=torch.float, device=input3d.device)
        std = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)

        batch_size = input3d.size(0)
        if batch_size > 0:
            ext_module.sync_bn_forward_mean(input3d, mean)
            batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype)
        else:
            # skip updating mean and leave it as zeros when the input is empty
            batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype)

        # synchronize mean and the batch flag
        vec = torch.cat([mean, batch_flag])
        if self.stats_mode == 'N':
            vec *= batch_size
        if self.group_size > 1:
            dist.all_reduce(vec, group=self.group)
        total_batch = vec[-1].detach()
        mean = vec[:num_channels]

        if self.stats_mode == 'default':
            mean = mean / self.group_size
        elif self.stats_mode == 'N':
            mean = mean / total_batch.clamp(min=1)
        else:
            raise NotImplementedError

        # leave var as zeros when the input is empty
        if batch_size > 0:
            ext_module.sync_bn_forward_var(input3d, mean, var)

        if self.stats_mode == 'N':
            var *= batch_size
        if self.group_size > 1:
            dist.all_reduce(var, group=self.group)

        if self.stats_mode == 'default':
            var /= self.group_size
        elif self.stats_mode == 'N':
            var /= total_batch.clamp(min=1)
        else:
            raise NotImplementedError

        # if the total batch size over all the ranks is zero,
        # we should not update the statistics in the current batch
        update_flag = total_batch.clamp(max=1)
        momentum = update_flag * self.momentum
        ext_module.sync_bn_forward_output(
            input3d,
            mean,
@@ -78,7 +119,7 @@ class SyncBatchNormFunction(Function):
            std,
            output3d,
            eps=self.eps,
            momentum=momentum,
            group_size=self.group_size)
        self.save_for_backward(norm, std, weight)
        return output
@@ -87,28 +128,67 @@ class SyncBatchNormFunction(Function):
    @once_differentiable
    def backward(self, grad_output):
        norm, std, weight = self.saved_tensors
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(weight)
        grad_input = torch.zeros_like(grad_output)
        grad_output3d = grad_output.flatten(start_dim=2)
        grad_input3d = grad_input.view_as(grad_output3d)

        batch_size = grad_input3d.size(0)
        if batch_size > 0:
            ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight,
                                              grad_bias)

        # all reduce
        if self.group_size > 1:
            dist.all_reduce(grad_weight, group=self.group)
            dist.all_reduce(grad_bias, group=self.group)
            grad_weight /= self.group_size
            grad_bias /= self.group_size

        if batch_size > 0:
            ext_module.sync_bn_backward_data(grad_output3d, weight,
                                             grad_weight, grad_bias, norm, std,
                                             grad_input3d)

        return grad_input, None, None, grad_weight, grad_bias, \
            None, None, None, None, None


@NORM_LAYERS.register_module(name='MMSyncBN')
class SyncBatchNorm(Module):
"""Synchronized Batch Normalization.
Args:
num_features (int): number of features/chennels in input tensor
eps (float, optional): a value added to the denominator for numerical
stability. Defaults to 1e-5.
momentum (float, optional): the value used for the running_mean and
running_var computation. Defaults to 0.1.
affine (bool, optional): whether to use learnable affine parameters.
Defaults to True.
track_running_stats (bool, optional): whether to track the running
mean and variance during training. When set to False, this
module does not track such statistics, and initializes statistics
buffers ``running_mean`` and ``running_var`` as ``None``. When
these buffers are ``None``, this module always uses batch
statistics in both training and eval modes. Defaults to True.
group (int, optional): synchronization of stats happen within
each process group individually. By default it is synchronization
across the whole world. Defaults to None.
stats_mode (str, optional): The statistical mode. Available options
includes ``'default'`` and ``'N'``. Defaults to 'default'.
When ``stats_mode=='default'``, it computes the overall statistics
using those from each worker with equal weight, i.e., the
statistics are synchronized and simply divied by ``group``. This
mode will produce inaccurate statistics when empty tensors occur.
When ``stats_mode=='N'``, it compute the overall statistics using
the total number of batches in each worker ignoring the number of
group, i.e., the statistics are synchronized and then divied by
the total batch ``N``. This mode is beneficial when empty tensors
occur during training, as it average the total mean by the real
number of batch.
"""
def __init__(self, def __init__(self,
num_features, num_features,
@@ -116,7 +196,8 @@ class SyncBatchNorm(Module):
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True,
                 group=None,
                 stats_mode='default'):
        super(SyncBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
@@ -126,6 +207,9 @@ class SyncBatchNorm(Module):
        group = dist.group.WORLD if group is None else group
        self.group = group
        self.group_size = dist.get_world_size(group)
        assert stats_mode in ['default', 'N'], \
            f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"'
        self.stats_mode = stats_mode
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
@@ -174,12 +258,10 @@ class SyncBatchNorm(Module):
            exponential_average_factor = self.momentum

        if self.training or not self.track_running_stats:
            return SyncBatchNormFunction.apply(
                input, self.running_mean, self.running_var, self.weight,
                self.bias, exponential_average_factor, self.eps, self.group,
                self.group_size, self.stats_mode)
        else:
            return F.batch_norm(input, self.running_mean, self.running_var,
                                self.weight, self.bias, False,
@@ -192,5 +274,6 @@ class SyncBatchNorm(Module):
        s += f'momentum={self.momentum}, '
        s += f'affine={self.affine}, '
        s += f'track_running_stats={self.track_running_stats}, '
        s += f'group_size={self.group_size},'
        s += f'stats_mode={self.stats_mode})'
        return s
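The core of the new forward pass above is to pack the per-channel mean and a one-element batch flag into a single vector before the all-reduce, so the summed statistics and the total number of contributing samples (or ranks) travel in one collective call. Below is a standalone sketch of that packing in plain PyTorch, independent of the CUDA extension; the helper name sync_mean and the example values are illustrative assumptions, not part of mmcv.

import torch
import torch.distributed as dist


def sync_mean(mean, batch_size, stats_mode='N', group=None, group_size=1):
    # flag is 1 when this rank saw data, 0 when its input tensor was empty
    flag = mean.new_full((1, ), float(batch_size > 0))
    vec = torch.cat([mean, flag])
    if stats_mode == 'N':
        vec = vec * batch_size  # turn per-rank means into per-rank sums
    if group_size > 1:
        dist.all_reduce(vec, group=group)  # one collective for mean + count
    total_batch = vec[-1]
    if stats_mode == 'default':
        return vec[:-1] / group_size, total_batch
    # 'N': divide by the real number of samples; clamp avoids 0/0 when
    # every rank is empty, leaving the mean as zeros
    return vec[:-1] / total_batch.clamp(min=1), total_batch


# single-process illustration: a rank with an empty input contributes nothing
m, n = sync_mean(torch.zeros(3), batch_size=0)
assert n.item() == 0 and torch.all(m == 0)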
from mmcv import Config  # isort:skip
cfg = Config.fromfile('./tests/data/config/a.py')
item5 = cfg.item1[0] + cfg.item2.a
@@ -2,6 +2,7 @@ import os
import platform

import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -141,6 +142,121 @@ class TestSyncBN(object):
        assert np.allclose(x_grad.data.cpu().numpy(),
                           sx_grad.data.cpu().numpy(), 1e-2)
    def _test_syncbn_empty_train(self, size=1, half=False):
        if 'SLURM_NTASKS' not in os.environ or int(
                os.environ['SLURM_NTASKS']) != 4:
            print('must be run with slurm on 4 processes!\n'
                  'srun -p test --gres=gpu:4 -n4')
            return
        else:
            print('Running syncbn test')
        from mmcv.ops import SyncBatchNorm

        assert size in (1, 2, 4)
        if not dist.is_initialized():
            self.dist_init()

        rank = dist.get_rank()
        torch.manual_seed(9)
        torch.cuda.manual_seed(9)

        self.x = torch.rand(0, 3, 2, 3).cuda()
        self.y_bp = torch.rand(0, 3, 2, 3).cuda()

        if half:
            self.x = self.x.half()
            self.y_bp = self.y_bp.half()
        dist.broadcast(self.x, src=0)
        dist.broadcast(self.y_bp, src=0)

        torch.cuda.synchronize()
        if size == 1:
            groups = [None, None, None, None]
            groups[0] = dist.new_group([0])
            groups[1] = dist.new_group([1])
            groups[2] = dist.new_group([2])
            groups[3] = dist.new_group([3])
            group = groups[rank]
        elif size == 2:
            groups = [None, None, None, None]
            groups[0] = groups[1] = dist.new_group([0, 1])
            groups[2] = groups[3] = dist.new_group([2, 3])
            group = groups[rank]
        elif size == 4:
            group = dist.group.WORLD

        syncbn = SyncBatchNorm(3, group=group, stats_mode='N').cuda()
        syncbn.weight.data[0] = 0.2
        syncbn.weight.data[1] = 0.5
        syncbn.weight.data[2] = 0.7
        syncbn.train()

        bn = nn.BatchNorm2d(3).cuda()
        bn.weight.data[0] = 0.2
        bn.weight.data[1] = 0.5
        bn.weight.data[2] = 0.7
        bn.train()

        sx = self.x[rank * 4:rank * 4 + 4]
        sx.requires_grad_()
        sy = syncbn(sx)
        sy.backward(self.y_bp[rank * 4:rank * 4 + 4])

        smean = syncbn.running_mean
        svar = syncbn.running_var
        sx_grad = sx.grad
        sw_grad = syncbn.weight.grad
        sb_grad = syncbn.bias.grad

        if size == 1:
            x = self.x[rank * 4:rank * 4 + 4]
            y_bp = self.y_bp[rank * 4:rank * 4 + 4]
        elif size == 2:
            x = self.x[rank // 2 * 8:rank // 2 * 8 + 8]
            y_bp = self.y_bp[rank // 2 * 8:rank // 2 * 8 + 8]
        elif size == 4:
            x = self.x
            y_bp = self.y_bp
        x.requires_grad_()
        y = bn(x)
        y.backward(y_bp)

        if size == 2:
            y = y[rank % 2 * 4:rank % 2 * 4 + 4]
        elif size == 4:
            y = y[rank * 4:rank * 4 + 4]

        mean = bn.running_mean
        var = bn.running_var
        if size == 1:
            x_grad = x.grad
            w_grad = bn.weight.grad
            b_grad = bn.bias.grad
        elif size == 2:
            x_grad = x.grad[rank % 2 * 4:rank % 2 * 4 + 4]
            w_grad = bn.weight.grad / 2
            b_grad = bn.bias.grad / 2
        elif size == 4:
            x_grad = x.grad[rank * 4:rank * 4 + 4]
            w_grad = bn.weight.grad / 4
            b_grad = bn.bias.grad / 4

        assert np.allclose(mean.data.cpu().numpy(),
                           smean.data.cpu().numpy(), 1e-3)
        assert np.allclose(var.data.cpu().numpy(),
                           svar.data.cpu().numpy(), 1e-3)
        assert np.allclose(y.data.cpu().numpy(), sy.data.cpu().numpy(), 1e-3)
        assert np.allclose(w_grad.data.cpu().numpy(),
                           sw_grad.data.cpu().numpy(), 1e-3)
        assert np.allclose(b_grad.data.cpu().numpy(),
                           sb_grad.data.cpu().numpy(), 1e-3)
        assert np.allclose(x_grad.data.cpu().numpy(),
                           sx_grad.data.cpu().numpy(), 1e-2)

        # 'stats_mode' only allows 'default' and 'N'
        with pytest.raises(AssertionError):
            SyncBatchNorm(3, group=group, stats_mode='X')
    def test_syncbn_1(self):
        self._test_syncbn_train(size=1)

@@ -158,3 +274,21 @@ class TestSyncBN(object):

    def test_syncbn_4_half(self):
        self._test_syncbn_train(size=4, half=True)
    def test_syncbn_empty_1(self):
        self._test_syncbn_empty_train(size=1)

    def test_syncbn_empty_2(self):
        self._test_syncbn_empty_train(size=2)

    def test_syncbn_empty_4(self):
        self._test_syncbn_empty_train(size=4)

    def test_syncbn_empty_1_half(self):
        self._test_syncbn_empty_train(size=1, half=True)

    def test_syncbn_empty_2_half(self):
        self._test_syncbn_empty_train(size=2, half=True)

    def test_syncbn_empty_4_half(self):
        self._test_syncbn_empty_train(size=4, half=True)
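A closing note on the momentum gating exercised by these tests: update_flag = total_batch.clamp(max=1) drives the momentum to zero whenever every rank is empty, so the usual exponential-moving-average update of the running statistics becomes a no-op. A hedged sketch of that arithmetic follows; the EMA formula shown is the conventional BatchNorm update, while in mmcv the actual update happens inside the sync_bn_forward_output kernel, and the numbers are illustrative only.

import torch

running_mean = torch.tensor([0.5, 0.5, 0.5])
batch_mean = torch.zeros(3)          # zeros because every rank was empty
total_batch = torch.tensor(0.)       # value of vec[-1] after the all-reduce

momentum = total_batch.clamp(max=1) * 0.1  # gated momentum -> 0.0
running_mean = (1 - momentum) * running_mean + momentum * batch_mean
# running_mean is unchanged: tensor([0.5000, 0.5000, 0.5000])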