Unverified commit 29aae007, authored by msbaines and committed by GitHub

[perf] SyncBatchNorm: avoid 2nd set of all_reduce when wrapped by checkpoint_wrapper (#694)

This change also ensures that we calculate running_{mean,var} correctly
when wrapped.
parent 3dcc9eff
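
For context, a minimal sketch of the setup this commit targets (the module sizes and the surrounding model are illustrative, not part of the diff; the names come from fairscale's public API). A `SyncBatchNorm` wrapped in `checkpoint_wrapper` has its forward run twice per training step: once in the checkpointing phase and again during backward-time recomputation. Before this change, each pass issued its own pair of `all_reduce` calls and re-applied the running-stats update.

```python
# Sketch: SyncBatchNorm inside an activation-checkpointed block. Assumes a
# torch.distributed process group has already been initialized.
import torch.nn as nn

from fairscale.experimental.nn import SyncBatchNorm
from fairscale.nn.checkpoint import checkpoint_wrapper

# The wrapped forward runs twice per step (checkpoint + recompute). After
# this commit, the cross-rank mean/invstd are computed and all_reduced only
# in the first pass and replayed in the second.
block = checkpoint_wrapper(nn.Sequential(nn.Conv3d(3, 8, 3), SyncBatchNorm(8)))
```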
@@ -3,80 +3,62 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
+from torch import Tensor
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+from fairscale.nn.checkpoint import is_checkpointing, is_recomputing
+
 
-def _forward(
-    input: torch.Tensor,
-    affine: bool,
-    track_running_stats: bool,
-    mean: torch.Tensor,
-    var: torch.Tensor,
-    invstd: torch.Tensor,
-    momentum: float,
-    weight: torch.Tensor,
-    bias: torch.Tensor,
-    running_mean: torch.Tensor,
-    running_var: torch.Tensor,
-    total_count: torch.Tensor,
-) -> torch.Tensor:
-    if track_running_stats:
-        with torch.no_grad():
-            unbiased_var = var * (total_count / (total_count - 1))
-            running_mean += momentum * (mean.reshape(-1) - running_mean)
-            running_var += momentum * (unbiased_var.reshape(-1) - running_var)
+def _forward(input: Tensor, affine: bool, mean: Tensor, invstd: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
     if affine:
         return (input - mean) * (invstd * weight.reshape_as(mean)) + bias.reshape_as(mean)
     else:
         return (input - mean) * invstd
 
 
+def _track_running_stats(
+    running_mean: Tensor, running_var: Tensor, momentum: float, mean: Tensor, var: Tensor, total_count: Tensor
+) -> None:
+    with torch.no_grad():
+        unbiased_var = var * (total_count / (total_count - 1))
+        running_mean += momentum * (mean.reshape(-1) - running_mean)
+        running_var += momentum * (unbiased_var.reshape(-1) - running_var)
+
+
+def _calculate_stats(input: Tensor, eps: float, process_group: ProcessGroup) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    dim = [d for d in range(input.ndim) if d != 1]
+    count = torch.full((1,), input.numel() // input.size(1), device=input.device, dtype=input.dtype)
+    total_count = count.clone()
+    all_reduce_handle = dist.all_reduce(total_count, group=process_group, async_op=True)
+    mean = torch.mean(input, dim=dim, keepdim=True)
+    meansqr = torch.mean(input * input, dim=dim, keepdim=True)
+    vec = torch.cat([mean, meansqr])
+    all_reduce_handle.wait()
+    vec = vec * (count / total_count)
+    dist.all_reduce(vec, group=process_group)
+    mean, meansqr = vec.chunk(2)
+    var = meansqr - mean * mean
+    invstd = torch.rsqrt(var + eps)
+    return mean, var, invstd, total_count
+
+
 if torch.__version__.split(".")[:2] >= ["1", "7"]:
     _forward = torch.jit.script(_forward)  # type: ignore
+    _track_running_stats = torch.jit.script(_track_running_stats)  # type: ignore
 
 
 class _SyncBatchNormFunction(torch.autograd.Function):
     @staticmethod
     # type: ignore
-    def forward(
-        ctx, input, weight, bias, affine, track_running_stats, running_mean, running_var, eps, momentum, process_group
-    ):
-        dim = [d for d in range(input.ndim) if d != 1]
-        count = torch.full((1,), input.numel() // input.size(1), device=input.device, dtype=input.dtype)
-        total_count = count.clone()
-        all_reduce_handle = dist.all_reduce(total_count, group=process_group, async_op=True)
-        mean = torch.mean(input, dim=dim, keepdim=True)
-        meansqr = torch.mean(input * input, dim=dim, keepdim=True)
-        vec = torch.cat([mean, meansqr])
-        all_reduce_handle.wait()
-        vec = vec * (count / total_count)
-        dist.all_reduce(vec, group=process_group)
-        mean, meansqr = vec.chunk(2)
-        var = meansqr - mean * mean
-        invstd = torch.rsqrt(var + eps)
+    def forward(ctx, input, weight, bias, affine, mean, invstd, total_count, process_group):
         ctx.save_for_backward(input, weight, bias, mean, invstd, total_count)
         ctx.process_group = process_group
-        return _forward(
-            input,
-            affine,
-            track_running_stats,
-            mean,
-            var,
-            invstd,
-            momentum,
-            weight,
-            bias,
-            running_mean,
-            running_var,
-            total_count,
-        )
+        return _forward(input, affine, mean, invstd, weight, bias)
 
     @staticmethod
     # type: ignore
@@ -138,22 +120,32 @@ class SyncBatchNorm(torch.nn.BatchNorm2d):
     ) -> None:
         super().__init__(*args, **kwargs)  # type: ignore
         self._process_group = process_group if process_group is not None else dist.group.WORLD
+        self.saved_for_2nd_fwd: List[Tuple] = []
+        self.disable_patch_batchnorm = True
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore
+    def forward(self, input: Tensor) -> Tensor:  # type: ignore
+        # There are 3 modes this is being called:
+        # 1. not wrapped (and there is only a single phase)
+        # 2. wrapped and in checkpointing phase
+        # 3. wrapped and in recomputing phase
         if not dist.is_initialized() or not self.training:
             return super().forward(input)
 
+        wrapped = is_checkpointing() or is_recomputing()
+        if not wrapped or is_checkpointing():
+            mean, var, invstd, total_count = _calculate_stats(input, self.eps, self._process_group)
+            if self.track_running_stats:
+                _track_running_stats(self.running_mean, self.running_var, self.momentum, mean, var, total_count)
+        if is_checkpointing():
+            self.saved_for_2nd_fwd.append((mean, invstd, total_count))
+            return _forward(input, self.affine, mean, invstd, self.weight, self.bias)
+        if is_recomputing():
+            mean, invstd, total_count = self.saved_for_2nd_fwd.pop(0)
+
         return _SyncBatchNormFunction.apply(
-            input,
-            self.weight,
-            self.bias,
-            self.affine,
-            self.track_running_stats,
-            self.running_mean,
-            self.running_var,
-            self.eps,
-            self.momentum,
-            self._process_group,
+            input, self.weight, self.bias, self.affine, mean, invstd, total_count, self._process_group
         )
 
     @classmethod
......
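
To summarize the control flow above: the stats (and their two `all_reduce` calls, plus the Bessel-corrected update of the running buffers via `var * total_count / (total_count - 1)`) are paid for exactly once per step, in the checkpointing or unwrapped pass, and replayed from a FIFO queue during recomputation. The following is a condensed sketch, not the fairscale source; `calc_stats`, `normalize`, and `autograd_apply` stand in for `_calculate_stats`, `_forward`, and `_SyncBatchNormFunction.apply`.

```python
# Condensed sketch of the three call modes handled by the new forward().
def forward_sketch(self, input):
    wrapped = is_checkpointing() or is_recomputing()
    if not wrapped or is_checkpointing():
        # Modes 1-2: pay for the two all_reduces and update the running
        # stats exactly once per training step.
        mean, var, invstd, total_count = calc_stats(input)
    if is_checkpointing():
        # Mode 2: stash the stats; a FIFO queue keeps replay order correct
        # when one wrapped module instance is called several times per step.
        self.saved_for_2nd_fwd.append((mean, invstd, total_count))
        return normalize(input, mean, invstd)
    if is_recomputing():
        # Mode 3: replay the saved stats -- no communication, no 2nd update.
        mean, invstd, total_count = self.saved_for_2nd_fwd.pop(0)
    return autograd_apply(input, mean, invstd, total_count)
```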
@@ -5,6 +5,6 @@
 
 from typing import List
 
-from .checkpoint_activations import checkpoint_wrapper
+from .checkpoint_activations import checkpoint_wrapper, is_checkpointing, is_recomputing
 
 __all__: List[str] = []
@@ -5,6 +5,7 @@
 
 from contextlib import contextmanager
 import functools
+import threading
 from typing import Any, Dict, Generator, Optional, Tuple
 import weakref
@@ -18,6 +19,70 @@ from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs
 
 from .checkpoint_utils import dec_counter, inc_counter, init_counter, patch_batchnorm
 
+# https://docs.python.org/3/library/threading.html#thread-local-data
+# Manage the checkpoint context with thread-local data.
+class ThreadLocal(threading.local):
+    def __init__(self) -> None:
+        self.is_checkpointing = False
+        self.is_recomputing = False
+
+
+thread_local = ThreadLocal()
+
+
+@contextmanager
+def enable_checkpointing() -> Generator[None, None, None]:
+    """Makes :func:`is_checkpointing` return :data:`True` within a context."""
+    orig = thread_local.is_checkpointing
+    thread_local.is_checkpointing = True
+    try:
+        yield
+    finally:
+        thread_local.is_checkpointing = orig
+
+
+@contextmanager
+def enable_recomputing() -> Generator[None, None, None]:
+    """Makes :func:`is_recomputing` return :data:`True` within a context."""
+    orig = thread_local.is_recomputing
+    thread_local.is_recomputing = True
+    try:
+        yield
+    finally:
+        thread_local.is_recomputing = orig
+
+
+def is_checkpointing() -> bool:
+    """Whether the current forward propagation is under checkpointing.
+
+    Returns:
+        bool: :data:`True` if it's under checkpointing.
+    """
+    return thread_local.is_checkpointing
+
+
+def is_recomputing() -> bool:
+    """Whether the current forward propagation is under checkpoint
+    recomputation. Use this to prevent duplicated side-effects at forward
+    propagation::
+
+        class Counter(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.counter = 0
+
+            def forward(self, input):
+                if not is_recomputing():
+                    self.counter += 1
+                return input
+
+    Returns:
+        bool: :data:`True` if it's under checkpoint recomputation.
+    """
+    return thread_local.is_recomputing
+
+
 def checkpoint_wrapper(
     module: nn.Module, offload_to_cpu: bool = False, maintain_forward_counter: bool = False
 ) -> nn.Module:
@@ -174,7 +239,7 @@ class CheckpointFunction(torch.autograd.Function):
         ctx.save_for_backward(*tensor_inputs)
         ctx.packed_non_tensor_inputs = packed_non_tensor_inputs
 
-        with torch.no_grad():
+        with torch.no_grad(), enable_checkpointing():
             unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
             outputs = run_function(*unpacked_args, **unpacked_kwargs)
             the_module = unpacked_args[0]
@@ -207,7 +272,7 @@ class CheckpointFunction(torch.autograd.Function):
         # Set the states to what it used to be before the forward pass.
         set_rng_state(ctx.fwd_rng_state)
 
-        with torch.enable_grad(), autocast(ctx.had_autocast_in_fwd):
+        with torch.enable_grad(), enable_recomputing(), autocast(ctx.had_autocast_in_fwd):
            unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
            tensor_outputs, _ = split_non_tensors(outputs)
......
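
A small self-contained illustration of the new thread-local flags (the `LoggingLayer` module is hypothetical, purely for demonstration): the wrapped forward observes `is_checkpointing()` during the first pass and `is_recomputing()` during the backward-time replay, which is exactly the signal the new `SyncBatchNorm.forward` keys off.

```python
import torch
import torch.nn as nn

from fairscale.nn.checkpoint import checkpoint_wrapper, is_checkpointing, is_recomputing

class LoggingLayer(nn.Module):
    """Hypothetical module that reports which phase its forward runs in."""

    def forward(self, x):
        print(f"checkpointing={is_checkpointing()} recomputing={is_recomputing()}")
        return x * 2

layer = checkpoint_wrapper(LoggingLayer())
y = layer(torch.ones(2, requires_grad=True))
# First print:  checkpointing=True  recomputing=False
y.sum().backward()
# Second print: checkpointing=False recomputing=True
```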
@@ -42,7 +42,7 @@ def patch_batchnorm(module: nn.Module) -> List:
     hooks = []
     for name, child in module.named_modules():
         # _BatchNorm is base for bn1d, bn2d, bn3d and sync_bn, apex_sync_bn, etc.
-        if isinstance(child, _BatchNorm):
+        if isinstance(child, _BatchNorm) and not hasattr(child, "disable_patch_batchnorm"):
             # Register the pre/post hooks.
             pre_handle = child.register_forward_pre_hook(pre_forward)
             post_handle = child.register_forward_hook(post_forward)
......
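
The `hasattr` check above gives modules an opt-out: the new `SyncBatchNorm` sets `disable_patch_batchnorm = True` in its `__init__` (see the first file) because it now handles the checkpoint and recompute phases itself, making the generic batchnorm pre/post hooks redundant. A sketch of the convention for a custom module, assuming the same attribute name:

```python
import torch.nn as nn

class SelfManagedBN(nn.BatchNorm2d):
    """Hypothetical batchnorm that opts out of fairscale's generic patching."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The attribute's presence (not its value) is what the hasattr
        # check tests, so patch_batchnorm will skip this module.
        self.disable_patch_batchnorm = True
```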
@@ -10,9 +10,11 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from fairscale.experimental.nn import SyncBatchNorm
+from fairscale.nn.checkpoint import checkpoint_wrapper
 
 pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@@ -71,13 +73,16 @@ def check_parity_ddp(torch_bn, fs_bn, x):
     fs_y = fs_ddp(fs_x)
     fs_y.backward(yh)
     torch.testing.assert_allclose(torch_y, fs_y)
-    torch.testing.assert_allclose(torch_x.grad, fs_x.grad)
+    if isinstance(torch_bn, nn.Sequential):
+        torch_bn = torch_bn[0]
+        fs_bn = fs_bn[0]
     torch.testing.assert_allclose(torch_bn.running_mean, fs_bn.running_mean)
     torch.testing.assert_allclose(torch_bn.running_var, fs_bn.running_var)
     torch.testing.assert_allclose(torch_bn.weight, fs_bn.weight)
     torch.testing.assert_allclose(torch_bn.bias, fs_bn.bias)
     torch.testing.assert_allclose(torch_bn.weight.grad, fs_bn.weight.grad)
     torch.testing.assert_allclose(torch_bn.bias.grad, fs_bn.bias.grad)
+    torch.testing.assert_allclose(torch_x.grad, fs_x.grad)
 
 
 @pg_test(world_size=1)
@@ -92,6 +97,34 @@
     check_parity(torch_bn, fs_bn, x)
 
 
+@pg_test()
+def parity3d_checkpoint_syncbn():
+    rank = dist.get_rank()
+    torch.cuda.set_device(rank)
+    torch.manual_seed(rank)
+    x = torch.randn(4, 3, 4, 4, 4).cuda() * rank
+    torch_bn = torch.nn.SyncBatchNorm(3).cuda()
+    fs_bn = SyncBatchNorm(3).cuda()
+    fs_bn = checkpoint_wrapper(fs_bn, maintain_forward_counter=True)
+    check_parity_ddp(torch_bn, fs_bn, x)
+
+
+@pg_test()
+def parity3d_checkpoint_syncbn_twice():
+    rank = dist.get_rank()
+    torch.cuda.set_device(rank)
+    torch.manual_seed(rank)
+    x = torch.randn(4, 3, 4, 4, 4).cuda() * rank
+    torch_bn = torch.nn.SyncBatchNorm(3)
+    torch_bn = nn.Sequential(torch_bn, torch_bn).cuda()
+    fs_bn = SyncBatchNorm(3)
+    fs_bn = nn.Sequential(fs_bn, fs_bn).cuda()
+    fs_bn = checkpoint_wrapper(fs_bn)
+    check_parity_ddp(torch_bn, fs_bn, x)
+
+
 @pg_test()
 def parity3d_syncbn():
     rank = dist.get_rank()
......