Unverified commit b60f3db0 authored by Min Xu, committed by GitHub

[test] improve a test's coverage (#798)



* checkpoint + nonflat + mixed_precision

* make tests pass with expected errors

* addressed comments

* add a comment
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 5f895f0b
@@ -1252,6 +1252,26 @@ class FullyShardedDataParallel(nn.Module):
         _registered = 0
 
         def _register_hook(t: torch.Tensor) -> torch.Tensor:
+            # We don't register the pre_backward hook on the same tensor that has been
+            # returned from an inner FSDP, unless it is the first one. This does
+            # not cover all problematic cases though. A tensor not from an inner
+            # FSDP can cause problems too:
+            # ```
+            #     x = layer1(input)
+            #     state = [x]  # better to use x.detach(); not fixed by the if-condition below
+            #     x = inner_fsdp_module_layer2(x)
+            #     state.append(x)  # better to use x.detach(), but fixed by the if-condition below
+            #     x = layer3(x)
+            #     return x, state
+            # ```
+            # The tensors in `state`, if not detached, can be registered with
+            # backward hooks (in addition to the `x` on the last line). In that case,
+            # the pre-backward hook can fire multiple times in an order that causes
+            # the outer FSDP to crash.
+            #
+            # The best practice is for a module wrapped by FSDP to return one and
+            # only one tensor to be used for backward. All other returned tensors
+            # should be detached.
             nonlocal _registered
             assert self._output_pre_backward_hook_registered is not None
             if t.requires_grad and (_registered == 0 or id(t) not in self._output_pre_backward_hook_registered):
...
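To make the recommendation in the new comment concrete, here is a minimal, illustrative sketch (not part of this commit; the class name and layer sizes are made up) of a module that returns exactly one tensor for backward and detaches everything else it hands back:

```python
import torch.nn as nn


class SingleBackwardOutputModule(nn.Module):
    """Illustrative only: follows the 'one tensor for backward' practice described above."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(3, 3)
        self.layer2 = nn.Linear(3, 3)  # stands in for an inner FSDP-wrapped layer
        self.layer3 = nn.Linear(3, 3)

    def forward(self, x):
        x = self.layer1(x)
        state = [x.detach()]      # detached, so no pre-backward hook is registered on it
        x = self.layer2(x)
        state.append(x.detach())  # detached as well
        x = self.layer3(x)
        return x, state           # only `x` carries a grad_fn back to the caller
```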
@@ -3,7 +3,9 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-""" Test FSDP with an submodule that is FSDP(checkpoint_wrapper()). """
+""" Test FSDP with a submodule that is FSDP(checkpoint_wrapper()) or checkpoint_wrapper(FSDP()). """
+
+import contextlib
 
 import pytest
 import torch
@@ -12,65 +14,121 @@ import torch.distributed
 import torch.multiprocessing as mp
 
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-from fairscale.nn.data_parallel import FullyShardedDataParallel
-from fairscale.utils import torch_version
+from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 
 
 @skip_if_single_gpu
-def test_train_and_eval_with_checkpointing():
-    if torch_version() < (1, 6, 0):
-        pytest.skip("older pytorch doesn't support reduce_scatter")
+@pytest.mark.parametrize("flatten", ["flat", "nonflat"])
+@pytest.mark.parametrize("mixed_precision", ["fp16", "fp32"])
+@pytest.mark.parametrize("amp_context", ["autocast", "noautocast"])
+@pytest.mark.parametrize("half_input", ["halfin", "fullin"])
+@pytest.mark.parametrize("fsdp_wrap_ckpt", ["F->C", "C->F"])
+def test_train_and_eval_with_checkpointing(flatten, mixed_precision, amp_context, half_input, fsdp_wrap_ckpt):
+    flatten = flatten == "flat"
+    mixed_precision = mixed_precision == "fp16"
+    amp_context = amp_context == "autocast"
+    half_input = half_input == "halfin"
+    fsdp_wrap_ckpt = fsdp_wrap_ckpt == "F->C"
 
     world_size = 2
     with temp_files_ctx(2) as (temp_file_name, unused):
         mp.spawn(
-            _test_func, args=(world_size, temp_file_name, unused), nprocs=world_size, join=True,
+            _test_func,
+            args=(
+                world_size,
+                temp_file_name,
+                unused,
+                flatten,
+                mixed_precision,
+                amp_context,
+                half_input,
+                fsdp_wrap_ckpt,
+            ),
+            nprocs=world_size,
+            join=True,
         )
 
 
-def _test_func(rank, world_size, tempfile_name, unused):
+def _test_func(
+    rank, world_size, tempfile_name, unused, flatten, mixed_precision, amp_context, half_input, fsdp_wrap_ckpt
+):
     result = dist_init(rank, world_size, tempfile_name, unused)
     assert result, "Dist init failed"
 
     # Keep initialization deterministic.
     torch.manual_seed(0)
 
-    model = FullyShardedDataParallel(SimpleModuleWithCheckpointing().cuda())
+    model = FSDP(
+        SimpleModuleWithCheckpointing(flatten, mixed_precision, fsdp_wrap_ckpt).cuda(),
+        flatten_parameters=flatten,
+        mixed_precision=mixed_precision,
+    )
     optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
     # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}
 
     # For clarity, this is what `expected_param_shapes` should look like depending on world size:
-    assert expected_param_shapes == {
-        "_fsdp_wrapped_module.flat_param_0": (12,),
-        "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0": (6,),
-    }, expected_param_shapes
+    if not flatten:
+        assert expected_param_shapes == {
+            "ffn.0.weight": (5,),
+            "ffn.0.bias": (2,),
+            "ffn.1.weight": (5,),
+            "ffn.1.bias": (2,),
+            "ffn.2.weight": (5,),
+            "ffn.2.bias": (2,),
+        }
+    else:
+        assert expected_param_shapes == {
+            "_fsdp_wrapped_module.flat_param_0": (12,),
+            "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0": (6,),
+        }, expected_param_shapes
 
     torch.manual_seed(1 + rank)
 
-    # Train for a step.
-    _train_step(model, optim, expected_param_shapes)
-
-    # Now do an eval step.
-    _eval_step(model, optim, expected_param_shapes)
-
-    # And finally do another train step.
-    _train_step(model, optim, expected_param_shapes)
+    # Expecting a known bug in 4 out of 32 cases.
+    context_train1 = contextlib.suppress()
+    context_train2 = contextlib.suppress()
+    if fsdp_wrap_ckpt and mixed_precision and not flatten:
+        context_train1 = pytest.raises(SystemError)
+        context_train2 = pytest.raises(ValueError)
+
+    with context_train1:
+        # Train for a step.
+        _train_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)
+
+        # Now do an eval step.
+        _eval_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)
+
+    with context_train2:
+        # And finally do another train step.
+        _train_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)
 
     teardown()
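The `contextlib.suppress()` / `pytest.raises()` switch above is what lets the four known-bad parameter combinations pass as expected failures while the other 28 cases run normally. A standalone sketch of the same pattern, with made-up helper names:

```python
import contextlib

import pytest


def _maybe_buggy_step(should_fail):
    # Stand-in for a step that hits a known bug only in some configurations.
    if should_fail:
        raise ValueError("known bug")


def run_step(should_fail):
    # Choose the context up front: expect the error when the bug applies,
    # otherwise use a no-op context so the same code path runs unchanged.
    ctx = pytest.raises(ValueError) if should_fail else contextlib.suppress()
    with ctx:
        _maybe_buggy_step(should_fail)


run_step(True)   # passes: the ValueError is expected and absorbed by pytest.raises
run_step(False)  # passes: nothing is raised and suppress() is a no-op
```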
-def _train_step(model, optim, expected_param_shapes):
+def _train_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input):
     # Prepare for training step.
     optim.zero_grad()
     model.train()
 
     # Create input and run forward pass.
     input = torch.randn(2, 3).cuda()
-    loss = model(input).sum()
+
+    # Make it FP16 when it is OK to do so.
+    if (amp_context and half_input) or (mixed_precision and half_input):
+        input = input.half()
+
+    context = contextlib.suppress()
+    if amp_context:
+        context = torch.cuda.amp.autocast(True)
+
+    with context:
+        loss = model(input).sum()
+
     _check_params(model, expected_param_shapes)
 
     # Run backward pass.
@@ -82,12 +140,18 @@ def _train_step(model, optim, expected_param_shapes):
     _check_params(model, expected_param_shapes)
 
 
-def _eval_step(model, optim, expected_param_shapes):
+def _eval_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input):
     optim.zero_grad()
     model.eval()
 
     with torch.no_grad():
         input = torch.randn(2, 3).cuda()
-        model(input).sum()
+        if (amp_context and half_input) or (mixed_precision and half_input):
+            input = input.half()
+        context = contextlib.suppress()
+        if amp_context:
+            context = torch.cuda.amp.autocast(True)
+        with context:
+            model(input).sum()
 
     _check_params(model, expected_param_shapes)
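Both `_train_step` and `_eval_step` now pick between a real autocast context and a no-op context so the forward call itself stays the same either way. A condensed sketch of that pattern (the function name is illustrative, and exercising the autocast branch needs a CUDA-enabled PyTorch build):

```python
import contextlib

import torch


def forward_sum(model, x, amp_context, mixed_precision, half_input):
    # FP16 input is only safe when autocast or FSDP mixed precision handles the cast.
    if half_input and (amp_context or mixed_precision):
        x = x.half()
    # Fall back to a no-op context when autocast is off, so the `with` block is identical.
    context = torch.cuda.amp.autocast(True) if amp_context else contextlib.suppress()
    with context:
        return model(x).sum()
```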
@@ -102,11 +166,18 @@ def _check_params(model, expected_param_shapes):
 
 
 class SimpleModuleWithCheckpointing(nn.Module):
-    def __init__(self):
+    def __init__(self, flatten, mixed_precision, fsdp_wrap_ckpt):
         super().__init__()
-        self.ffn = nn.Sequential(
-            nn.Linear(3, 3), FullyShardedDataParallel(checkpoint_wrapper(nn.Linear(3, 3))), nn.Linear(3, 3),
-        )
+        if fsdp_wrap_ckpt:
+            middle_module = FSDP(
+                checkpoint_wrapper(nn.Linear(3, 3)), flatten_parameters=flatten, mixed_precision=mixed_precision
+            )
+        else:
+            middle_module = checkpoint_wrapper(
+                FSDP(nn.Linear(3, 3), flatten_parameters=flatten, mixed_precision=mixed_precision)
+            )
+        self.ffn = nn.Sequential(nn.Linear(3, 3), middle_module, nn.Linear(3, 3))
 
     def forward(self, x):
         return self.ffn(x)
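For reference, the shapes asserted in `_test_func` follow from this module and `world_size = 2`: each `nn.Linear(3, 3)` holds 3 * 3 + 3 = 12 parameters. A quick check of that arithmetic (the even-split padding in the non-flat case is inferred from the asserted shapes, not read from the FSDP sharding code):

```python
import math

linear_params = 3 * 3 + 3  # weight + bias of one nn.Linear(3, 3) = 12
world_size = 2

# Flattened case: ffn.0 and ffn.2 are flattened by the outer FSDP, ffn.1 by the
# inner FSDP, and each flat parameter is sharded across the two ranks.
assert (2 * linear_params) // world_size == 12  # outer flat_param_0 -> (12,)
assert linear_params // world_size == 6         # inner flat_param_0 -> (6,)

# Non-flat case: each tensor is sharded on its own; the asserted shapes imply
# padding to an even split, so every rank holds the same shard size.
assert math.ceil(3 * 3 / world_size) == 5  # weight: 9 elements -> (5,) per rank
assert math.ceil(3 / world_size) == 2      # bias: 3 elements -> (2,) per rank
```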