Unverified commit 370b8483, authored by Pete, committed by GitHub

Use original forward pass directly when in eval mode from within checkpoint wrapper (#709)

* add failing test

* add fix

* use 'torch.is_grad_enabled()' instead of 'module.training'

* Revert "add failing test"

This reverts commit 1c34242208f9b2c5fa6c8f181434c2be6d7cdbc0.

* add simple test

* improve test

* add check for fwd_counter

* revert typing/format changes

* move to new test file

* CHANGELOG

* remove old test

* fix import order

* fix test to be compat with torch 1.6.0

* clean up

* comments

* isort 🤦
parent d60fc284
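
The gist of the change, before the diff itself: the checkpoint wrapper now short-circuits to the module's original forward whenever autograd is disabled (e.g. under torch.no_grad() during evaluation), so none of the checkpointing bookkeeping runs. A minimal sketch of that pattern, using torch.utils.checkpoint rather than fairscale's internals and with illustrative names only:

import torch
from torch.utils.checkpoint import checkpoint


def maybe_checkpointed_forward(module, *inputs):
    # Sketch only: with gradients disabled there is no backward pass, so skip
    # checkpointing entirely and avoid touching any forward-pass bookkeeping
    # (such as a forward counter that is only decremented during backward).
    if not torch.is_grad_enabled():
        return module(*inputs)
    # With gradients enabled, checkpoint the module so its activations are
    # recomputed during the backward pass instead of being stored.
    return checkpoint(module, *inputs)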
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## NEXT - TBD
### Fixed
- checkpointing: use dummy tensor to ensure backward pass is called [#701]
- checkpointing: ensure internal fwd counter is not incremented in eval mode [#709]
- FSDP: fixed bug where buffers returned in `state_dict()` could still be half precision when `mixed_precision` is set to `True`.
### Added
@@ -157,10 +157,20 @@ def checkpoint_wrapper(
def _checkpointed_forward(
original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
) -> Any:
module = weak_self()
# If gradients are disabled, just use original `.forward()` method directly.
# Doing so also ensures the internal fwd counter is not incremented in the forward pass,
# which would be an issue during eval since there wouldn't be a corresponding backward pass
# to decrement the fwd counter.
# See https://github.com/facebookresearch/fairscale/pull/709.
if not torch.is_grad_enabled():
return original_forward(module, *args, **kwargs)
# Autograd Functions in PyTorch work best with positional args, since
# the backward must return gradients (or None) for every input argument.
# We can flatten keyword arguments to make this easier.
args = (weak_self(),) + args
args = (module,) + args
kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
parent_ctx_dict: Dict[str, Any] = {
"offload": offload_to_cpu,
@@ -227,7 +237,6 @@ class CheckpointFunction(torch.autograd.Function):
*args: Any,
**kwargs: Any
) -> Any:
if torch.is_grad_enabled(): # grad may be disabled, e.g., during validation
torch_checkpoint.check_backward_validity(args)
ctx.run_function = run_function
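
The comment about positional arguments above reflects a general constraint of torch.autograd.Function: backward() must return one gradient (or None) per positional input to forward(), so keyword arguments are flattened into the positional list before entering the Function and rebuilt inside it. A rough sketch of that flatten/unflatten idea (illustrative only, not fairscale's actual pack_kwargs implementation):

from typing import Any, Dict, Tuple


def flatten_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    # Record the keyword names, then append their values after the positional args.
    kwarg_keys = tuple(kwargs.keys())
    return kwarg_keys, args + tuple(kwargs[k] for k in kwarg_keys)


def unflatten_kwargs(kwarg_keys: Tuple[str, ...], flat_args: Tuple[Any, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    # Split the flat tuple back into the original (args, kwargs) pair.
    if not kwarg_keys:
        return flat_args, {}
    n = len(kwarg_keys)
    return flat_args[:-n], dict(zip(kwarg_keys, flat_args[-n:]))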
@@ -9,3 +9,4 @@ tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp_with_checkpoint_wrapper.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
""" Test FSDP with an submodule that is FSDP(checkpoint_wrapper()). """

import pytest
import torch
from torch import nn
import torch.distributed
import torch.multiprocessing as mp

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version


@skip_if_single_gpu
def test_train_and_eval_with_checkpointing():
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter")
world_size = 2
with temp_files_ctx(2) as (temp_file_name, unused):
mp.spawn(
_test_func, args=(world_size, temp_file_name, unused), nprocs=world_size, join=True,
)


def _test_func(rank, world_size, tempfile_name, unused):
result = dist_init(rank, world_size, tempfile_name, unused)
assert result, "Dist init failed"
# Keep initialization deterministic.
torch.manual_seed(0)
model = FullyShardedDataParallel(SimpleModuleWithCheckpointing().cuda())
optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Collect parameter sizes to ensure these stay consistent through the steps below.
expected_param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}
# For clarity, this is what `expected_param_shapes` should look like depending on world size:
assert expected_param_shapes == {
"_fsdp_wrapped_module.flat_param": (12,),
"_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param": (6,),
}, expected_param_shapes
torch.manual_seed(1 + rank)
# Train for a step.
_train_step(model, optim, expected_param_shapes)
# Now do an eval step.
_eval_step(model, optim, expected_param_shapes)
# And finally do another train step.
_train_step(model, optim, expected_param_shapes)
teardown()


def _train_step(model, optim, expected_param_shapes):
# Prepare for training step.
optim.zero_grad()
model.train()
# Create input and run forward pass.
input = torch.randn(2, 3).cuda()
loss = model(input).sum()
_check_fwd_counter(model, 1)
_check_params(model, expected_param_shapes)
# Run backward pass.
loss.backward()
_check_fwd_counter(model, 0)
_check_params(model, expected_param_shapes)
# Finally, take a step.
optim.step()
_check_params(model, expected_param_shapes)


def _eval_step(model, optim, expected_param_shapes):
optim.zero_grad()
model.eval()
with torch.no_grad():
input = torch.randn(2, 3).cuda()
model(input).sum()
_check_fwd_counter(model, 0)
_check_params(model, expected_param_shapes)


def _check_params(model, expected_param_shapes):
current_param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}
assert set(current_param_shapes.keys()) == set(expected_param_shapes.keys())
for key, current_shape in current_param_shapes.items():
expected_shape = expected_param_shapes[key]
assert (
current_shape == expected_shape
), f"Parameter {key} should have shape {expected_shape}, but found shape {current_shape}"


def _check_fwd_counter(model, expected_value):
current_value = model._fpw_module.ffn[1]._fsdp_wrapped_module.module._checkpoint_fwd_counter
assert (
current_value == expected_value
), f"forward counter of checkpointed submodule should be {expected_value}, but found {current_value}"


class SimpleModuleWithCheckpointing(nn.Module):
def __init__(self):
super().__init__()
self.ffn = nn.Sequential(
nn.Linear(3, 3),
FullyShardedDataParallel(checkpoint_wrapper(nn.Linear(3, 3), maintain_forward_counter=True)),
nn.Linear(3, 3),
)

    def forward(self, x):
return self.ffn(x)
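
Outside of the distributed test above, the train/eval behaviour being verified can be reproduced with a small single-process sketch (hypothetical usage, no FSDP and no GPU required; layer sizes mirror the test):

import torch
from torch import nn

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

# A checkpointed submodule sandwiched between plain linear layers, as in the test.
model = nn.Sequential(nn.Linear(3, 3), checkpoint_wrapper(nn.Linear(3, 3)), nn.Linear(3, 3))

# Training step: the wrapped layer goes through the checkpointing autograd Function,
# so its activations are recomputed during backward.
model.train()
loss = model(torch.randn(2, 3, requires_grad=True)).sum()
loss.backward()

# Eval step under no_grad(): with this fix the wrapper calls the original forward
# directly, so no checkpointing state (such as the forward counter) is touched.
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3))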