Unverified commit 5e64d6a7 authored by Min Xu, committed by GitHub

[feat]: checkpoint and normalization (#457)

* [feat]: checkpoint and normalization

- added special handling of BN's track_running_stats under checkpointing
- we test BN/LN with checkpointing
- we test them with mixed precision
parent b36e01d5
......@@ -15,6 +15,8 @@ import torch.utils.checkpoint as torch_checkpoint
from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
from .misc import patch_batchnorm
def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
"""
......@@ -61,15 +63,20 @@ def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
Returns:
(nn.Module):
wrapped module
Wrapped module
"""
# Patch the batchnorm layers in case there are any.
patch_batchnorm(module)
# The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
# When such cycle exists, gc won't collect the module when the module is freed.
# That causes GPU memory to be leaked. See the unit test for how we catch that.
#
# We prefer this over a class wrapper since the class wrapper would have to
# proxy a lot of fields and methods.
module.forward = functools.partial(_checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu) # type: ignore
module.forward = functools.partial( # type: ignore
_checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu
)
return module
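The weakref pattern above is what breaks the m -> m.forward -> m cycle. A minimal standalone sketch of the same idea (the MyModule and _patched_forward names below are illustrative, not part of this diff):

import functools
import weakref

import torch
from torch import nn

def _patched_forward(original_forward, weak_self, *args, **kwargs):
    # Recover the module from the weak reference; storing only a weakref on
    # module.forward avoids a strong module -> forward -> module cycle.
    module = weak_self()
    return original_forward(module, *args, **kwargs)

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

m = MyModule()
# Bind the class-level (unbound) forward plus a weak reference to the instance.
m.forward = functools.partial(_patched_forward, type(m).forward, weakref.ref(m))
out = m(torch.randn(2, 4))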
......@@ -81,7 +88,9 @@ def _checkpointed_forward(
# We can flatten keyword arguments to make this easier.
args = (weak_self(),) + args
kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu}
parent_ctx_dict: Dict[str, Any] = {
"offload": offload_to_cpu,
}
output = CheckpointFunction.apply(original_forward, parent_ctx_dict, kwarg_keys, *flat_args)
if not isinstance(output, torch.Tensor):
packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
......
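pack_kwargs above flattens positional and keyword arguments into one flat tuple so they can be passed positionally through CheckpointFunction.apply, and unpack_kwargs reverses it inside the autograd function. A rough sketch of that flattening (an illustration of the idea, not the exact fairscale.utils.containers implementation):

from typing import Any, Dict, Tuple

def pack_kwargs_sketch(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    # Remember the keyword names; append keyword values after the positionals.
    kwarg_keys = tuple(kwargs.keys())
    flat_args = tuple(args) + tuple(kwargs.values())
    return kwarg_keys, flat_args

def unpack_kwargs_sketch(
    kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    # The trailing len(kwarg_keys) entries are keyword values, the rest positional.
    if not kwarg_keys:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys):]))
    return args, kwargs

# kwarg_keys == ("b",) and flat_args == (1, 2, 3); unpacking restores (1, 2) and {"b": 3}
kwarg_keys, flat_args = pack_kwargs_sketch(1, 2, b=3)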
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
import torch
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm
def patch_batchnorm(module: nn.Module) -> List:
"""Patch all batchnorm instances (1d, 2d, 3d, sync_bn, etc.) of a module
so that they don't track running stats when torch.no_grad() is enabled.
This is important in activation checkpointing to ensure stats are tracked
correctly as if there were no activation checkpointing. The reason is
that activation checkpointing runs the forward function twice, first
with torch.no_grad(), then with torch.enable_grad().
Args:
module (nn.Module):
The module to be patched in-place.
Returns:
(list):
A list of hook handles that can later be removed to undo the patch.
"""
hooks = []
def pre_forward(module: _BatchNorm, input: Tensor) -> None:
if torch.is_grad_enabled():
return
module._track_running_stats_backup = module.track_running_stats
module.track_running_stats = False
def post_forward(module: _BatchNorm, input: Tensor, result: Tensor) -> None:
if torch.is_grad_enabled():
return
module.track_running_stats = module._track_running_stats_backup
for name, child in module.named_modules():
# _BatchNorm is base for bn1d, bn2d, bn3d and sync_bn, apex_sync_bn, etc.
if isinstance(child, _BatchNorm):
# Register the pre/post hooks.
pre_handle = child.register_forward_pre_hook(pre_forward)
post_handle = child.register_forward_hook(post_forward)
hooks += [pre_handle, post_handle]
return hooks
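A quick sketch of the effect (layer sizes and data below are arbitrary): the first, no_grad forward pass that checkpointing performs leaves the running stats alone, and the recomputation pass during backward (grad enabled) updates them exactly once, just as an unwrapped module would.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

bn = nn.BatchNorm2d(4)
patch_batchnorm(bn)  # installs the pre/post forward hooks defined above

x = torch.randn(2, 4, 8, 8, requires_grad=True)
out = checkpoint(bn, x)   # no_grad forward: hooks disable stat tracking
out.sum().backward()      # recomputation under grad: stats tracked once
assert bn.num_batches_tracked.item() == 1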
......@@ -15,6 +15,9 @@ class _BatchNorm(Module):
weight: Parameter = ...
bias: Parameter = ...
# This field is used by fairscale.nn.misc.misc::patch_batchnorm
_track_running_stats_backup: bool
# MODIFIED BY TORCHGPIPE
running_mean: Tensor
running_var: Tensor
......@@ -28,10 +31,6 @@ class _BatchNorm(Module):
def reset_parameters(self) -> None: ...
def forward(self, input: Tensor) -> Tensor: ... # type: ignore
def __call__(self, input: Tensor) -> Tensor: ... # type: ignore
class BatchNorm1d(_BatchNorm): ...
......@@ -46,7 +45,3 @@ class SyncBatchNorm(_BatchNorm):
# TODO set process_group to the right type once torch.distributed is stubbed
def __init__(self, num_features: int, eps: float = ..., momentum: float = ..., affine: bool = ...,
track_running_stats: bool = ..., process_group: Optional[Any] = ...) -> None: ...
def forward(self, input: Tensor) -> Tensor: ... # type: ignore
def __call__(self, input: Tensor) -> Tensor: ... # type: ignore
tests/nn/misc/test_flatten_params_wrapper.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/wrap/test_wrap.py
......@@ -2,6 +2,9 @@ tests/utils/test_reduce_scatter_bucketer.py
tests/utils/test_containers.py
tests/utils/test_parallel.py
tests/utils/test_state_dict.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations_norm.py
tests/nn/wrap/test_wrap.py
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_transparency.py
tests/nn/pipe_process/test_inplace.py
......
tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_features_sharded_ddp.py
tests/nn/data_parallel/test_pytorch_parity_sharded_ddp.py
tests/nn/pipe/skip/test_gpipe.py
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test checkpoint_wrapper with normalization layers. """
import pytest
import torch
from torch.nn import BatchNorm2d, LayerNorm, Linear, Sequential
from torch.optim import SGD
from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
from fairscale.utils.testing import objects_are_equal, torch_version
NORM_TYPES = [LayerNorm, BatchNorm2d]
MP_TYPES = ["fp32", "fp16", "call_half"]
def get_model(norm_type, checkpointed, mixed_precision):
assert norm_type in NORM_TYPES, norm_type
assert checkpointed in [True, False], checkpointed
assert mixed_precision in MP_TYPES
model = Sequential(Linear(3, 2), norm_type(2))
if mixed_precision == "fp16":
# Set param.data and buffers as fp16
for p in model.parameters():
p.data = p.data.half()
for m in model:
for n, b in m.named_buffers():
setattr(m, n, b.half())
elif mixed_precision == "call_half":
model.half()
if checkpointed:
model = checkpoint_wrapper(model)
return model
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("norm_type", NORM_TYPES)
@pytest.mark.parametrize("mixed_precision", MP_TYPES)
def test_norm(device, norm_type, mixed_precision):
"""Test checkpoint_wrapper with different norm layers."""
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("Skip due to lack of GPU")
# Get input, ref, checkpoint models and make them equal.
in_data = torch.rand(2, 2, 3, 3).to(device)
m_ref = get_model(norm_type, False, mixed_precision).to(device)
m_cpt = get_model(norm_type, True, mixed_precision).to(device)
m_cpt.load_state_dict(m_ref.state_dict())
if torch_version() >= (1, 6, 0):
# This assert fails on 1.5.1.
assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())
if mixed_precision != "fp32":
in_data = in_data.half()
# Needed due to checkpointing.
in_data.requires_grad = True
for model in (m_ref, m_cpt):
optim = SGD(model.parameters(), lr=0.1)
if device == "cpu" and mixed_precision != "fp32":
# Got: RuntimeError: "batch_norm"/"layer_norm" not implemented for 'Half'.
with pytest.raises(RuntimeError):
out = model(in_data)
return
else:
# Everything else works.
out = model(in_data)
out.sum().backward()
optim.step()
if torch_version() >= (1, 6, 0):
assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())
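Outside the test, the wrapper is used the same way; a minimal training-step sketch with the same illustrative shapes:

import torch
from torch.nn import BatchNorm2d, Linear, Sequential
from torch.optim import SGD

from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper

model = checkpoint_wrapper(Sequential(Linear(3, 2), BatchNorm2d(2)))
optim = SGD(model.parameters(), lr=0.1)

in_data = torch.rand(2, 2, 3, 3, requires_grad=True)  # requires_grad is needed for checkpointing
out = model(in_data)
out.sum().backward()
optim.step()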