Unverified Commit f0a40046 authored by msbaines, committed by GitHub

[feat] experimental.nn.SyncBatchNorm: initial commit (#662)

* [feat] experimental.nn.SyncBatchNorm: initial commit

Fast/simple re-implementation of SyncBatchNorm.

When profiling SSL Vision, I was seeing a majority of cycles spent in
SyncBatchNorm. With this change, I see a 10% to 20% speedup on the
model I was profiling.
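For context, a minimal drop-in usage sketch (the model and device names are
illustrative, not part of this commit): convert existing BatchNorm layers with
convert_sync_batchnorm and wrap the model in DDP as usual.

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    import fairscale.experimental.nn

    # Hypothetical model; any module containing BatchNorm*D layers works.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3), torch.nn.BatchNorm2d(64)).cuda()
    # Swap every BatchNorm*D layer for the fairscale implementation.
    model = fairscale.experimental.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[torch.cuda.current_device()])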

When running benchmarks/experimental/sync_batchnorm.py on 8 x V100,
I get a 6x speedup:

<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
Elapsed time is  0.08709120750427246
Elapsed time is  0.12632274627685547
Elapsed time is  0.14095258712768555
Elapsed time is  0.16529417037963867
Elapsed time is  0.1419970989227295
Elapsed time is  0.15166854858398438
Elapsed time is  0.12000870704650879
Elapsed time is  0.17534875869750977
<class 'torch.nn.modules.batchnorm.SyncBatchNorm'>
Elapsed time is  2.5087168216705322
Elapsed time is  2.497001886367798
Elapsed time is  2.5204885005950928
Elapsed time is  2.526789903640747
Elapsed time is  2.5080230236053467
Elapsed time is  2.524489641189575
Elapsed time is  2.513214588165283
Elapsed time is  2.5359973907470703
<class 'fairscale.experimental.nn.sync_batchnorm.SyncBatchNorm'>
Elapsed time is  0.4126114845275879
Elapsed time is  0.39051294326782227
Elapsed time is  0.40685415267944336
Elapsed time is  0.4159870147705078
Elapsed time is  0.42383885383605957
Elapsed time is  0.4080159664154053
Elapsed time is  0.41202712059020996
Elapsed time is  0.42400121688842773
parent b54eed1b
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import fairscale.experimental.nn


def benchmark_bn(rank, world_size, init_file, bn_cls):
    dist.init_process_group(dist.Backend.NCCL, init_method="file://" + init_file, rank=rank, world_size=world_size)
    x = torch.randn(50, 2048, 7, 7).to(rank)
    bn = bn_cls(2048).to(rank)
    bn = DDP(bn, device_ids=[rank])
    # Warmup
    for i in range(50):
        with torch.no_grad():
            x = bn(x)
    torch.cuda.synchronize(rank)
    t0 = time.time()
    for i in range(100):
        with torch.no_grad():
            x = bn(x)
    torch.cuda.synchronize(rank)
    t1 = time.time()
    print("Elapsed time is ", t1 - t0)


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    for cls in [torch.nn.BatchNorm2d, torch.nn.SyncBatchNorm, fairscale.experimental.nn.SyncBatchNorm]:
        print(cls)
        mp.spawn(benchmark_bn, args=(world_size, tempfile.mkstemp()[1], cls), nprocs=world_size)
@@ -6,5 +6,6 @@
 from typing import List

 from .offload import OffloadModel
+from .sync_batchnorm import SyncBatchNorm

 __all__: List[str] = []
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

if torch.__version__.split(".")[:2] >= ["1", "8"]:
    from torch.distributed.nn.functional import all_reduce as differentiable_all_reduce
else:
    # Copied from https://github.com/pytorch/pytorch/blob/v1.8.1/torch/distributed/nn/functional.py
    class _AllReduce(torch.autograd.Function):
        @staticmethod
        def forward(ctx, op, group, tensor):  # type: ignore
            ctx.group = group
            ctx.op = op
            tensor = tensor.clone()
            dist.all_reduce(tensor, op=op, group=group)
            return tensor

        @staticmethod
        def backward(ctx, grad_output):  # type: ignore
            # The gradient of an all_reduce is itself an all_reduce of the
            # incoming gradient; no gradient flows to `op` or `group`.
            return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)

    def differentiable_all_reduce(tensor, op=dist.ReduceOp.SUM, group=dist.group.WORLD):  # type: ignore
        return _AllReduce.apply(op, group, tensor)
def _forward(
    input: torch.Tensor,
    affine: bool,
    track_running_stats: bool,
    mean: torch.Tensor,
    meansqr: torch.Tensor,
    momentum: float,
    eps: float,
    weight: torch.Tensor,
    bias: torch.Tensor,
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    total_count: torch.Tensor,
) -> torch.Tensor:
    # Var[x] = E[x^2] - E[x]^2, computed from the globally reduced moments.
    var = meansqr - mean * mean
    if track_running_stats:
        with torch.no_grad():
            unbiased_var = var * (total_count / (total_count - 1))
            running_mean += momentum * (mean.reshape(-1) - running_mean)
            running_var += momentum * (unbiased_var.reshape(-1) - running_var)
    invstd = torch.rsqrt(var + eps)
    if affine:
        return (input - mean) * invstd * weight.reshape(mean.shape) + bias.reshape(mean.shape)
    else:
        return (input - mean) * invstd


if torch.__version__.split(".")[:2] >= ["1", "7"]:
    _forward = torch.jit.script(_forward)  # type: ignore
class SyncBatchNorm(torch.nn.BatchNorm2d):
    """
    Fast re-implementation of ``torch.nn.SyncBatchNorm`` that can achieve a speedup
    of 5x or more over the default implementation, depending on the size of the
    input and the number of distributed workers.
    """

    def __init__(
        self, *args: Tuple[Any, ...], process_group: Optional[ProcessGroup] = None, **kwargs: Dict[str, Any]
    ) -> None:
        super().__init__(*args, **kwargs)  # type: ignore
        self._process_group = process_group if process_group is not None else dist.group.WORLD

    def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore
        if not dist.is_initialized() or not self.training:
            return super().forward(input)

        # Reduce over every dimension except the channel dimension (dim 1).
        dim = [d for d in range(input.ndim) if d != 1]
        count = torch.full((1,), input.numel() // input.size(1), device=input.device, dtype=input.dtype)
        total_count = count.clone()
        # Overlap the count reduction with the local moment computation.
        handle = dist.all_reduce(total_count, group=self._process_group, async_op=True)
        mean = torch.mean(input, dim=dim, keepdim=True)
        meansqr = torch.mean(input * input, dim=dim, keepdim=True)
        vec = torch.cat([mean, meansqr])
        handle.wait()
        # Weight the local moments by this worker's share of the global count,
        # then sum across workers so a single all_reduce yields the global
        # mean and E[x^2].
        vec = vec * (count / total_count)
        mean, meansqr = differentiable_all_reduce(vec, group=self._process_group).chunk(2)  # type: ignore
        return _forward(
            input,
            self.affine,
            self.track_running_stats,
            mean,
            meansqr,
            self.momentum,
            self.eps,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            total_count,
        )
    @classmethod
    def convert_sync_batchnorm(
        cls, module: torch.nn.Module, process_group: Optional[ProcessGroup] = None
    ) -> torch.nn.Module:
        r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to
        :class:`fairscale.experimental.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted
            :class:`fairscale.experimental.nn.SyncBatchNorm` layers. If the original
            :attr:`module` is a :attr:`BatchNorm*D` layer, a new
            :class:`fairscale.experimental.nn.SyncBatchNorm` layer object will be
            returned instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> module = torch.nn.Sequential(
            >>>     torch.nn.Linear(20, 100),
            >>>     torch.nn.BatchNorm1d(100),
            >>> ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
            >>> sync_bn_module = fairscale.experimental.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = SyncBatchNorm(
                module.num_features,  # type: ignore
                module.eps,  # type: ignore
                module.momentum,  # type: ignore
                module.affine,  # type: ignore
                module.track_running_stats,  # type: ignore
                process_group=process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group))
        del module
        return module_output
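A note on the reduction in forward() above: each worker scales its local
moments by its share of the global element count before the sum all_reduce, so
a single reduction yields the exact global mean and E[x^2] even when workers
hold unequal batch sizes. A small single-process sketch (values are
illustrative, not from this commit) checking that identity:

    import torch

    # Two "workers" with unequal element counts.
    chunks = [torch.randn(5), torch.randn(3)]
    counts = torch.tensor([float(c.numel()) for c in chunks])
    total = counts.sum()

    # Each worker contributes mean_i * (n_i / N); summing these (the
    # all_reduce in forward()) recovers the global mean sum_i n_i * mean_i / N.
    weighted = sum(c.mean() * (n / total) for c, n in zip(chunks, counts))
    assert torch.allclose(weighted, torch.cat(chunks).mean())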
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import Optional

from torch import Tensor
from torch.distributed import ProcessGroup, ReduceOp

def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group: Optional[ProcessGroup] = None): ...
@@ -5,3 +5,4 @@ tests/nn/data_parallel/test_fsdp_freezing_weights.py
 tests/nn/data_parallel/test_fsdp.py
 tests/nn/misc/test_flatten_params_wrapper.py
 tests/nn/pipe/test_parity.py
+tests/experimental/nn/test_sync_batchnorm.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import tempfile

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from fairscale.experimental.nn import SyncBatchNorm

pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")


def pg_worker(rank, world_size, init_file, func, *args):
    dist.init_process_group(dist.Backend.NCCL, init_method="file://" + init_file, rank=rank, world_size=world_size)
    func(*args)
    dist.destroy_process_group()


def pg_test(world_size=torch.cuda.device_count()):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            mp.spawn(pg_worker, args=(world_size, tempfile.mkstemp()[1], func, *kwargs.values()), nprocs=world_size)

        # Register the spawning wrapper under a "test_" name so pytest collects
        # it, while returning the original func for the spawned workers to run.
        globals()["test_" + func.__name__] = wrapper
        return func

    return decorator
def check_parity(torch_bn, fs_bn, x):
    yh = torch.ones_like(x)
    torch_y = torch_bn(x)
    fs_y = fs_bn(x)
    torch_y.backward(yh)
    fs_y.backward(yh)
    assert torch.allclose(torch_y, fs_y), f"{torch_y} != {fs_y}"
    assert torch.allclose(torch_bn.running_mean, fs_bn.running_mean), f"{torch_bn.running_mean} != {fs_bn.running_mean}"
    assert torch.allclose(torch_bn.running_var, fs_bn.running_var), f"{torch_bn.running_var} != {fs_bn.running_var}"
    assert torch.allclose(torch_bn.weight, fs_bn.weight), f"{torch_bn.weight} != {fs_bn.weight}"
    assert torch.allclose(torch_bn.bias, fs_bn.bias), f"{torch_bn.bias} != {fs_bn.bias}"
    # TODO(msb) currently disabled due to PyTorch bug: https://github.com/pytorch/pytorch/issues/57796
    # assert torch.allclose(torch_bn.weight.grad, fs_bn.weight.grad), f"{torch_bn.weight.grad} != {fs_bn.weight.grad}"
    assert torch.allclose(torch_bn.bias.grad, fs_bn.bias.grad), f"{torch_bn.bias.grad} != {fs_bn.bias.grad}"


def check_parity_ddp(torch_bn, fs_bn, x):
    yh = torch.ones_like(x)
    rank = dist.get_rank()
    torch_ddp = DDP(torch_bn, device_ids=[rank])
    fs_ddp = DDP(fs_bn, device_ids=[rank])
    torch_bn = torch_ddp.module
    fs_bn = fs_ddp.module
    torch_y = torch_ddp(x)
    fs_y = fs_ddp(x)
    torch_y.backward(yh)
    fs_y.backward(yh)
    assert torch.allclose(torch_y, fs_y), f"{torch_y} != {fs_y}"
    assert torch.allclose(torch_bn.running_mean, fs_bn.running_mean), f"{torch_bn.running_mean} != {fs_bn.running_mean}"
    assert torch.allclose(torch_bn.running_var, fs_bn.running_var), f"{torch_bn.running_var} != {fs_bn.running_var}"
    assert torch.allclose(torch_bn.weight, fs_bn.weight), f"{torch_bn.weight} != {fs_bn.weight}"
    assert torch.allclose(torch_bn.bias, fs_bn.bias), f"{torch_bn.bias} != {fs_bn.bias}"
    # TODO(msb) currently disabled due to PyTorch bug: https://github.com/pytorch/pytorch/issues/57796
    # assert torch.allclose(torch_bn.weight.grad, fs_bn.weight.grad), f"{torch_bn.weight.grad} != {fs_bn.weight.grad}"
    assert torch.allclose(torch_bn.bias.grad, fs_bn.bias.grad), f"{torch_bn.bias.grad} != {fs_bn.bias.grad}"
@pg_test(world_size=1)
def parity3d_bn():
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    x = torch.randn(4, 3, 4, 4, 4).cuda()
    torch_bn = torch.nn.BatchNorm3d(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    check_parity(torch_bn, fs_bn, x)


@pg_test()
def parity3d_syncbn():
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    x = torch.randn(4, 3, 4, 4, 4).cuda() * rank
    torch_bn = torch.nn.SyncBatchNorm(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    check_parity_ddp(torch_bn, fs_bn, x)


@pg_test(world_size=1)
def parity2d_bn():
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    x = torch.randn(4, 3, 4, 4).cuda()
    torch_bn = torch.nn.BatchNorm2d(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    check_parity(torch_bn, fs_bn, x)


@pg_test()
def parity2d_syncbn():
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    x = torch.randn(4, 3, 4, 4).cuda() * rank
    torch_bn = torch.nn.SyncBatchNorm(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    check_parity_ddp(torch_bn, fs_bn, x)


@pg_test(world_size=1)
def parity1d_bn():
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    x = torch.randn(4, 3, 4).cuda()
    torch_bn = torch.nn.BatchNorm1d(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    check_parity(torch_bn, fs_bn, x)


@pg_test()
def parity1d_syncbn():
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    x = torch.randn(4, 3, 4).cuda()
    torch_bn = torch.nn.SyncBatchNorm(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    check_parity_ddp(torch_bn, fs_bn, x)