Unverified Commit 29aae007 authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[perf] SyncBatchNorm: avoid 2nd set of all_reduce when wrapped by checkpoint_wrapper (#694)

This change also ensures that we calculate running_{mean,var} correctly
when wrapped.
parent 3dcc9eff
...@@ -3,80 +3,62 @@ ...@@ -3,80 +3,62 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from torch import Tensor
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from fairscale.nn.checkpoint import is_checkpointing, is_recomputing
def _forward(
input: torch.Tensor,
affine: bool,
track_running_stats: bool,
mean: torch.Tensor,
var: torch.Tensor,
invstd: torch.Tensor,
momentum: float,
weight: torch.Tensor,
bias: torch.Tensor,
running_mean: torch.Tensor,
running_var: torch.Tensor,
total_count: torch.Tensor,
) -> torch.Tensor:
if track_running_stats:
with torch.no_grad():
unbiased_var = var * (total_count / (total_count - 1))
running_mean += momentum * (mean.reshape(-1) - running_mean)
running_var += momentum * (unbiased_var.reshape(-1) - running_var)
def _forward(input: Tensor, affine: bool, mean: Tensor, invstd: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
if affine: if affine:
return (input - mean) * (invstd * weight.reshape_as(mean)) + bias.reshape_as(mean) return (input - mean) * (invstd * weight.reshape_as(mean)) + bias.reshape_as(mean)
else: else:
return (input - mean) * invstd return (input - mean) * invstd
def _track_running_stats(
running_mean: Tensor, running_var: Tensor, momentum: float, mean: Tensor, var: Tensor, total_count: Tensor
) -> None:
with torch.no_grad():
unbiased_var = var * (total_count / (total_count - 1))
running_mean += momentum * (mean.reshape(-1) - running_mean)
running_var += momentum * (unbiased_var.reshape(-1) - running_var)
def _calculate_stats(input: Tensor, eps: float, process_group: ProcessGroup) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Compute cross-process (global) batch statistics for ``input``.

    Returns:
        Tuple of (mean, var, invstd, total_count): per-channel statistics
        all-reduced over every rank in ``process_group``; ``total_count``
        is the global number of elements contributing per channel.
    """
    # Reduce over every dimension except the channel dim (dim 1).
    dim = [d for d in range(input.ndim) if d != 1]
    # Number of elements this rank contributes per channel.
    count = torch.full((1,), input.numel() // input.size(1), device=input.device, dtype=input.dtype)
    total_count = count.clone()
    # Start the count reduction asynchronously so it overlaps with the local
    # mean / mean-of-squares computation below.
    all_reduce_handle = dist.all_reduce(total_count, group=process_group, async_op=True)
    mean = torch.mean(input, dim=dim, keepdim=True)
    meansqr = torch.mean(input * input, dim=dim, keepdim=True)
    # Concatenate so both statistics travel in a single all_reduce.
    vec = torch.cat([mean, meansqr])
    all_reduce_handle.wait()
    # Weight local means by this rank's share of the global count, so the
    # sum across ranks yields the global mean and mean-of-squares.
    vec = vec * (count / total_count)
    dist.all_reduce(vec, group=process_group)
    mean, meansqr = vec.chunk(2)
    # var = E[x^2] - E[x]^2; invstd is what normalization actually consumes.
    var = meansqr - mean * mean
    invstd = torch.rsqrt(var + eps)
    return mean, var, invstd, total_count
def _jit_is_available() -> bool:
    """Return True when torch is new enough (>= 1.7) to script the helpers.

    The version is compared numerically: comparing the split string
    components directly is lexicographic, which would wrongly classify
    e.g. torch "1.10" as older than "1.7".
    """
    try:
        major, minor = (int(part) for part in torch.__version__.split(".")[:2])
    except ValueError:  # unusual dev/nightly version strings — be conservative
        return False
    return (major, minor) >= (1, 7)


if _jit_is_available():
    _forward = torch.jit.script(_forward)  # type: ignore
    _track_running_stats = torch.jit.script(_track_running_stats)  # type: ignore
class _SyncBatchNormFunction(torch.autograd.Function): class _SyncBatchNormFunction(torch.autograd.Function):
@staticmethod @staticmethod
# type: ignore # type: ignore
def forward( def forward(ctx, input, weight, bias, affine, mean, invstd, total_count, process_group):
ctx, input, weight, bias, affine, track_running_stats, running_mean, running_var, eps, momentum, process_group
):
dim = [d for d in range(input.ndim) if d != 1]
count = torch.full((1,), input.numel() // input.size(1), device=input.device, dtype=input.dtype)
total_count = count.clone()
all_reduce_handle = dist.all_reduce(total_count, group=process_group, async_op=True)
mean = torch.mean(input, dim=dim, keepdim=True)
meansqr = torch.mean(input * input, dim=dim, keepdim=True)
vec = torch.cat([mean, meansqr])
all_reduce_handle.wait()
vec = vec * (count / total_count)
dist.all_reduce(vec, group=process_group)
mean, meansqr = vec.chunk(2)
var = meansqr - mean * mean
invstd = torch.rsqrt(var + eps)
ctx.save_for_backward(input, weight, bias, mean, invstd, total_count) ctx.save_for_backward(input, weight, bias, mean, invstd, total_count)
ctx.process_group = process_group ctx.process_group = process_group
return _forward( return _forward(input, affine, mean, invstd, weight, bias)
input,
affine,
track_running_stats,
mean,
var,
invstd,
momentum,
weight,
bias,
running_mean,
running_var,
total_count,
)
@staticmethod @staticmethod
# type: ignore # type: ignore
...@@ -138,22 +120,32 @@ class SyncBatchNorm(torch.nn.BatchNorm2d): ...@@ -138,22 +120,32 @@ class SyncBatchNorm(torch.nn.BatchNorm2d):
) -> None: ) -> None:
super().__init__(*args, **kwargs) # type: ignore super().__init__(*args, **kwargs) # type: ignore
self._process_group = process_group if process_group is not None else dist.group.WORLD self._process_group = process_group if process_group is not None else dist.group.WORLD
self.saved_for_2nd_fwd: List[Tuple] = []
self.disable_patch_batchnorm = True
def forward(self, input: Tensor) -> Tensor: # type: ignore
# There are 3 modes this is being called:
# 1. not wrapped (and there is only a single phase)
# 2. wrapped and in checkpointing phase
# 3. wrapped and in recomputing phase
def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore
if not dist.is_initialized() or not self.training: if not dist.is_initialized() or not self.training:
return super().forward(input) return super().forward(input)
wrapped = is_checkpointing() or is_recomputing()
if not wrapped or is_checkpointing():
mean, var, invstd, total_count = _calculate_stats(input, self.eps, self._process_group)
if self.track_running_stats:
_track_running_stats(self.running_mean, self.running_var, self.momentum, mean, var, total_count)
if is_checkpointing():
self.saved_for_2nd_fwd.append((mean, invstd, total_count))
return _forward(input, self.affine, mean, invstd, self.weight, self.bias)
if is_recomputing():
mean, invstd, total_count = self.saved_for_2nd_fwd.pop(0)
return _SyncBatchNormFunction.apply( return _SyncBatchNormFunction.apply(
input, input, self.weight, self.bias, self.affine, mean, invstd, total_count, self._process_group
self.weight,
self.bias,
self.affine,
self.track_running_stats,
self.running_mean,
self.running_var,
self.eps,
self.momentum,
self._process_group,
) )
@classmethod @classmethod
......
...@@ -5,6 +5,6 @@ ...@@ -5,6 +5,6 @@
from typing import List from typing import List
from .checkpoint_activations import checkpoint_wrapper from .checkpoint_activations import checkpoint_wrapper, is_checkpointing, is_recomputing
__all__: List[str] = [] __all__: List[str] = []
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from contextlib import contextmanager from contextlib import contextmanager
import functools import functools
import threading
from typing import Any, Dict, Generator, Optional, Tuple from typing import Any, Dict, Generator, Optional, Tuple
import weakref import weakref
...@@ -18,6 +19,70 @@ from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kw ...@@ -18,6 +19,70 @@ from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kw
from .checkpoint_utils import dec_counter, inc_counter, init_counter, patch_batchnorm from .checkpoint_utils import dec_counter, inc_counter, init_counter, patch_batchnorm
class ThreadLocal(threading.local):
    """Per-thread flags recording the current checkpoint phase.

    See https://docs.python.org/3/library/threading.html#thread-local-data
    for how ``threading.local`` gives each thread its own attribute values.
    """

    def __init__(self) -> None:
        super().__init__()
        self.is_checkpointing = False
        self.is_recomputing = False


thread_local = ThreadLocal()
@contextmanager
def enable_checkpointing() -> Generator[None, None, None]:
    """Context manager under which :func:`is_checkpointing` returns :data:`True`.

    The previous flag value is restored on exit, even if the body raises.
    """
    previous = thread_local.is_checkpointing
    thread_local.is_checkpointing = True
    try:
        yield
    finally:
        thread_local.is_checkpointing = previous
@contextmanager
def enable_recomputing() -> Generator[None, None, None]:
    """Context manager under which :func:`is_recomputing` returns :data:`True`.

    The previous flag value is restored on exit, even if the body raises.
    """
    previous = thread_local.is_recomputing
    thread_local.is_recomputing = True
    try:
        yield
    finally:
        thread_local.is_recomputing = previous
def is_checkpointing() -> bool:
    """Whether the current forward pass runs under checkpointing.

    Returns:
        bool: :data:`True` if it's under checkpointing.
    """
    checkpointing: bool = thread_local.is_checkpointing
    return checkpointing
def is_recomputing() -> bool:
    """Whether the current forward pass is a checkpoint recomputation.

    Use this to avoid duplicated side effects when a wrapped module's
    forward runs a second time during backward::

        class Counter(nn.Module):
            def __init__(self):
                super().__init__()
                self.counter = 0

            def forward(self, input):
                if not is_recomputing():
                    self.counter += 1
                return input

    Returns:
        bool: :data:`True` if it's under checkpoint recomputation.
    """
    recomputing: bool = thread_local.is_recomputing
    return recomputing
def checkpoint_wrapper( def checkpoint_wrapper(
module: nn.Module, offload_to_cpu: bool = False, maintain_forward_counter: bool = False module: nn.Module, offload_to_cpu: bool = False, maintain_forward_counter: bool = False
) -> nn.Module: ) -> nn.Module:
...@@ -174,7 +239,7 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -174,7 +239,7 @@ class CheckpointFunction(torch.autograd.Function):
ctx.save_for_backward(*tensor_inputs) ctx.save_for_backward(*tensor_inputs)
ctx.packed_non_tensor_inputs = packed_non_tensor_inputs ctx.packed_non_tensor_inputs = packed_non_tensor_inputs
with torch.no_grad(): with torch.no_grad(), enable_checkpointing():
unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args) unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
outputs = run_function(*unpacked_args, **unpacked_kwargs) outputs = run_function(*unpacked_args, **unpacked_kwargs)
the_module = unpacked_args[0] the_module = unpacked_args[0]
...@@ -207,7 +272,7 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -207,7 +272,7 @@ class CheckpointFunction(torch.autograd.Function):
# Set the states to what it used to be before the forward pass. # Set the states to what it used to be before the forward pass.
set_rng_state(ctx.fwd_rng_state) set_rng_state(ctx.fwd_rng_state)
with torch.enable_grad(), autocast(ctx.had_autocast_in_fwd): with torch.enable_grad(), enable_recomputing(), autocast(ctx.had_autocast_in_fwd):
unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs) unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
tensor_outputs, _ = split_non_tensors(outputs) tensor_outputs, _ = split_non_tensors(outputs)
......
...@@ -42,7 +42,7 @@ def patch_batchnorm(module: nn.Module) -> List: ...@@ -42,7 +42,7 @@ def patch_batchnorm(module: nn.Module) -> List:
hooks = [] hooks = []
for name, child in module.named_modules(): for name, child in module.named_modules():
# _BatchNorm is base for bn1d, bn2d, bn3d and sync_bn, apex_sync_bn, etc. # _BatchNorm is base for bn1d, bn2d, bn3d and sync_bn, apex_sync_bn, etc.
if isinstance(child, _BatchNorm): if isinstance(child, _BatchNorm) and not hasattr(child, "disable_patch_batchnorm"):
# Register the pre/post hooks. # Register the pre/post hooks.
pre_handle = child.register_forward_pre_hook(pre_forward) pre_handle = child.register_forward_pre_hook(pre_forward)
post_handle = child.register_forward_hook(post_forward) post_handle = child.register_forward_hook(post_forward)
......
...@@ -10,9 +10,11 @@ import pytest ...@@ -10,9 +10,11 @@ import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.experimental.nn import SyncBatchNorm from fairscale.experimental.nn import SyncBatchNorm
from fairscale.nn.checkpoint import checkpoint_wrapper
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
...@@ -71,13 +73,16 @@ def check_parity_ddp(torch_bn, fs_bn, x): ...@@ -71,13 +73,16 @@ def check_parity_ddp(torch_bn, fs_bn, x):
fs_y = fs_ddp(fs_x) fs_y = fs_ddp(fs_x)
fs_y.backward(yh) fs_y.backward(yh)
torch.testing.assert_allclose(torch_y, fs_y) torch.testing.assert_allclose(torch_y, fs_y)
torch.testing.assert_allclose(torch_x.grad, fs_x.grad)
if isinstance(torch_bn, nn.Sequential):
torch_bn = torch_bn[0]
fs_bn = fs_bn[0]
torch.testing.assert_allclose(torch_bn.running_mean, fs_bn.running_mean) torch.testing.assert_allclose(torch_bn.running_mean, fs_bn.running_mean)
torch.testing.assert_allclose(torch_bn.running_var, fs_bn.running_var) torch.testing.assert_allclose(torch_bn.running_var, fs_bn.running_var)
torch.testing.assert_allclose(torch_bn.weight, fs_bn.weight) torch.testing.assert_allclose(torch_bn.weight, fs_bn.weight)
torch.testing.assert_allclose(torch_bn.bias, fs_bn.bias) torch.testing.assert_allclose(torch_bn.bias, fs_bn.bias)
torch.testing.assert_allclose(torch_bn.weight.grad, fs_bn.weight.grad) torch.testing.assert_allclose(torch_bn.weight.grad, fs_bn.weight.grad)
torch.testing.assert_allclose(torch_bn.bias.grad, fs_bn.bias.grad) torch.testing.assert_allclose(torch_bn.bias.grad, fs_bn.bias.grad)
torch.testing.assert_allclose(torch_x.grad, fs_x.grad)
@pg_test(world_size=1) @pg_test(world_size=1)
...@@ -92,6 +97,34 @@ def parity3d_bn(): ...@@ -92,6 +97,34 @@ def parity3d_bn():
check_parity(torch_bn, fs_bn, x) check_parity(torch_bn, fs_bn, x)
@pg_test()
def parity3d_checkpoint_syncbn():
    """Parity of checkpoint-wrapped fairscale SyncBatchNorm vs torch SyncBatchNorm on 3d input."""
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    inputs = torch.randn(4, 3, 4, 4, 4).cuda() * rank
    reference_bn = torch.nn.SyncBatchNorm(3).cuda()
    candidate_bn = checkpoint_wrapper(SyncBatchNorm(3).cuda(), maintain_forward_counter=True)
    check_parity_ddp(reference_bn, candidate_bn, inputs)
@pg_test()
def parity3d_checkpoint_syncbn_twice():
    """Same parity check, but the (shared) BN module appears twice in a Sequential."""
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    inputs = torch.randn(4, 3, 4, 4, 4).cuda() * rank
    shared_torch_bn = torch.nn.SyncBatchNorm(3)
    reference = nn.Sequential(shared_torch_bn, shared_torch_bn).cuda()
    shared_fs_bn = SyncBatchNorm(3)
    candidate = checkpoint_wrapper(nn.Sequential(shared_fs_bn, shared_fs_bn).cuda())
    check_parity_ddp(reference, candidate, inputs)
@pg_test() @pg_test()
def parity3d_syncbn(): def parity3d_syncbn():
rank = dist.get_rank() rank = dist.get_rank()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment