"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "c27bed4521a5d0c2128df330476721bfbcf0f99d"
Unverified commit 370b8483, authored by Pete and committed by GitHub

Use original forward pass directly when in eval mode from within checkpoint wrapper (#709)

* add failing test

* add fix

* use 'torch.is_grad_enabled()' instead of 'module.training'

* Revert "add failing test"

This reverts commit 1c34242208f9b2c5fa6c8f181434c2be6d7cdbc0.

* add simple test

* improve test

* add check for fwd_counter

* revert typing/format changes

* move to new test file

* CHANGELOG

* remove old test

* fix import order

* fix test to be compat with torch 1.6.0

* clean up

* comments

* isort 🤦
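For context, here is a minimal single-process sketch (no FSDP, CPU only, not part of the commit) of the behavior this change produces. It assumes the `_checkpoint_fwd_counter` attribute is readable directly on the checkpoint-wrapped layer, as the new test below reads it through the FSDP wrappers:

import torch
from torch import nn

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

# A checkpointed layer between two plain layers, mirroring the test's
# SimpleModuleWithCheckpointing but without the FSDP wrappers.
model = nn.Sequential(
    nn.Linear(3, 3),
    checkpoint_wrapper(nn.Linear(3, 3), maintain_forward_counter=True),
    nn.Linear(3, 3),
)
checkpointed = model[1]

# Train step: the checkpointed forward increments the internal counter,
# and the corresponding backward pass decrements it back to zero.
model.train()
loss = model(torch.randn(2, 3)).sum()
assert checkpointed._checkpoint_fwd_counter == 1
loss.backward()
assert checkpointed._checkpoint_fwd_counter == 0

# Eval step under no_grad: with this change the original forward is used
# directly, so the counter is never incremented. Before the fix it would be
# incremented here with no backward pass to decrement it.
model.eval()
with torch.no_grad():
    model(torch.randn(2, 3))
assert checkpointed._checkpoint_fwd_counter == 0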
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
 - checkpointing: use dummy tensor to ensure backward pass is called [#701]
+- checkpointing: ensure internal fwd counter is not incremented in eval mode [#709]
 - FSDP: fixed bug where buffers returned in `state_dict()` could still be half precision when `mixed_precision` is set to `True`.
 ### Added
...
@@ -157,10 +157,20 @@ def checkpoint_wrapper(
 def _checkpointed_forward(
     original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
 ) -> Any:
+    module = weak_self()
+
+    # If gradients are disabled, just use original `.forward()` method directly.
+    # Doing so also ensures the internal fwd counter is not incremented in the forward pass,
+    # which would be an issue during eval since there wouldn't be a corresponding backward pass
+    # to decrement the fwd counter.
+    # See https://github.com/facebookresearch/fairscale/pull/709.
+    if not torch.is_grad_enabled():
+        return original_forward(module, *args, **kwargs)
+
     # Autograd Functions in PyTorch work best with positional args, since
     # the backward must return gradients (or None) for every input argument.
     # We can flatten keyword arguments to make this easier.
-    args = (weak_self(),) + args
+    args = (module,) + args
     kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
     parent_ctx_dict: Dict[str, Any] = {
         "offload": offload_to_cpu,
@@ -227,8 +237,7 @@ class CheckpointFunction(torch.autograd.Function):
         *args: Any,
         **kwargs: Any
     ) -> Any:
-        if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
-            torch_checkpoint.check_backward_validity(args)
+        torch_checkpoint.check_backward_validity(args)
         ctx.run_function = run_function
         ctx.kwarg_keys = kwarg_keys
...
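Note on the hunk above: because `_checkpointed_forward` now returns early through the original forward whenever `torch.is_grad_enabled()` is False, `CheckpointFunction.forward` can only be reached with gradients enabled. The previous guard is therefore redundant, and `torch_checkpoint.check_backward_validity(args)` runs unconditionally.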
@@ -9,3 +9,4 @@ tests/nn/data_parallel/test_fsdp_summon_full_params.py
 tests/nn/data_parallel/test_fsdp_input.py
 tests/nn/data_parallel/test_fsdp_optimizer_utils.py
 tests/nn/data_parallel/test_fsdp.py
+tests/nn/data_parallel/test_fsdp_with_checkpoint_wrapper.py
New file: tests/nn/data_parallel/test_fsdp_with_checkpoint_wrapper.py

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

""" Test FSDP with a submodule that is FSDP(checkpoint_wrapper()). """

import pytest
import torch
from torch import nn
import torch.distributed
import torch.multiprocessing as mp

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version


@skip_if_single_gpu
def test_train_and_eval_with_checkpointing():
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    world_size = 2

    with temp_files_ctx(2) as (temp_file_name, unused):
        mp.spawn(
            _test_func, args=(world_size, temp_file_name, unused), nprocs=world_size, join=True,
        )


def _test_func(rank, world_size, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FullyShardedDataParallel(SimpleModuleWithCheckpointing().cuda())
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}

    # For clarity, this is what `expected_param_shapes` should look like depending on world size:
    assert expected_param_shapes == {
        "_fsdp_wrapped_module.flat_param": (12,),
        "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param": (6,),
    }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes)

    teardown()


def _train_step(model, optim, expected_param_shapes):
    # Prepare for training step.
    optim.zero_grad()
    model.train()

    # Create input and run forward pass.
    input = torch.randn(2, 3).cuda()
    loss = model(input).sum()
    _check_fwd_counter(model, 1)
    _check_params(model, expected_param_shapes)

    # Run backward pass.
    loss.backward()
    _check_fwd_counter(model, 0)
    _check_params(model, expected_param_shapes)

    # Finally, take a step.
    optim.step()
    _check_params(model, expected_param_shapes)


def _eval_step(model, optim, expected_param_shapes):
    optim.zero_grad()
    model.eval()

    with torch.no_grad():
        input = torch.randn(2, 3).cuda()
        model(input).sum()

    _check_fwd_counter(model, 0)
    _check_params(model, expected_param_shapes)


def _check_params(model, expected_param_shapes):
    current_param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}
    assert set(current_param_shapes.keys()) == set(expected_param_shapes.keys())
    for key, current_shape in current_param_shapes.items():
        expected_shape = expected_param_shapes[key]
        assert (
            current_shape == expected_shape
        ), f"Parameter {key} should have shape {expected_shape}, but found shape {current_shape}"


def _check_fwd_counter(model, expected_value):
    current_value = model._fpw_module.ffn[1]._fsdp_wrapped_module.module._checkpoint_fwd_counter
    assert (
        current_value == expected_value
    ), f"forward counter of checkpointed submodule should be {expected_value}, but found {current_value}"


class SimpleModuleWithCheckpointing(nn.Module):
    def __init__(self):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(3, 3),
            FullyShardedDataParallel(checkpoint_wrapper(nn.Linear(3, 3), maintain_forward_counter=True)),
            nn.Linear(3, 3),
        )

    def forward(self, x):
        return self.ffn(x)
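The test above runs the same FSDP-wrapped model through a train step, an eval step under `torch.no_grad()`, and another train step, checking after each phase that the checkpointed submodule's `_checkpoint_fwd_counter` returns to zero and that parameter shapes stay consistent. It is decorated with `skip_if_single_gpu` and spawns `world_size = 2` processes, so it needs at least two GPUs, and it is skipped on PyTorch older than 1.6.0 (no `reduce_scatter` support). Assuming a standard pytest setup, it can be invoked with `pytest tests/nn/data_parallel/test_fsdp_with_checkpoint_wrapper.py`.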