Unverified Commit b75a5e26 authored by Min Xu, committed by GitHub

[fix] FSDP: fix the corner case where all params are in the children (#441)

* [fix] FSDP corner case of all params in the children

* lint

* fix

* tradeoff

* fix doc build

* review comments
parent bd04f21f
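For context, a minimal sketch of the corner case this commit addresses: every parameter is owned by a child FSDP wrapper, so the outer (root) FSDP has no parameters of its own and, before this fix, could fail to queue its post-backward callback. The single-process NCCL setup, layer sizes, and the `FSDP` alias below are illustrative assumptions, not part of the change.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Single-process process group; fairscale's FSDP uses CUDA streams internally,
# so this sketch assumes one GPU is available.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

# Every parameter is owned by a child FSDP wrapper, so the outer (root) FSDP
# below wraps no parameters of its own -- the corner case this commit handles.
inner = nn.Sequential(
    FSDP(nn.Linear(8, 4).cuda()),
    FSDP(nn.Linear(4, 8).cuda()),
)
model = FSDP(inner)

out = model(torch.rand(4, 8, device="cuda"))
# Before this fix, a parameter-less root could fail to queue its
# wait_for_post_backward callback during this backward pass.
out.sum().backward()

dist.destroy_process_group()
```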
@@ -3,5 +3,5 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
-from .fully_sharded_data_parallel import FullyShardedDataParallel
+from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState
from .sharded_ddp import ShardedDataParallel
@@ -8,7 +8,7 @@ import copy
from enum import Enum, auto
import functools
from math import inf
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
import torch
from torch.autograd import Variable
@@ -94,6 +94,13 @@ class FullyShardedDataParallel(nn.Module):
since FSDP will shard parameters in-place and this will break any
previously initialized optimizers.
.. warning::
If you wrap every parameter inside a nested FSDP and leave the outer
FSDP without any parameters of its own, activation checkpointing may
trigger an assert during the backward pass. The solution is to leave
some parameters to the outer FSDP.
Args:
module (nn.Module):
module to checkpoint
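A hedged illustration of the workaround the new warning suggests: when every layer is wrapped in a nested FSDP and activation checkpointing is used, keep at least some parameters directly under the outer FSDP. The nesting order and layer sizes below are assumptions for illustration; `checkpoint_wrapper` is imported from the same path used in the tests further down, and an initialized process group (as in the setup sketch after the commit message) is assumed at call time.

```python
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper


def build_model():
    # Every parameter below sits inside a child FSDP and activations are
    # checkpointed; without the final unwrapped Linear, the outer FSDP would
    # own no parameters, which is the pattern the warning cautions against.
    return FSDP(nn.Sequential(
        checkpoint_wrapper(FSDP(nn.Linear(8, 4))),
        checkpoint_wrapper(FSDP(nn.Linear(4, 4))),
        nn.Linear(4, 8),  # left unwrapped so the root FSDP has its own params
    ))
```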
@@ -172,7 +179,8 @@ class FullyShardedDataParallel(nn.Module):
# shard any leftover parameters.
params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))
-if self.flatten_parameters and len(params) > 0:
+self._has_params = len(params) > 0
+if self.flatten_parameters and self._has_params:
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
del module # free original module in case it helps garbage collection
self.params = [self._fsdp_wrapped_module.flat_param]
@@ -191,7 +199,7 @@ class FullyShardedDataParallel(nn.Module):
# Flag to indicate if we require gradient reduction in the backward
# pass. This will be False when inside the no_sync context manager.
-self.require_backward_grad_sync: bool = True
+self._require_backward_grad_sync: bool = True
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self.training_state = TrainingState.IDLE
@@ -251,9 +259,13 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
# We don't call torch.cuda.synchronize() here, since clipping can be
# inside the train loop and we probably don't want to force a GPU-CPU sync.
# _lazy_init should be sufficient, since it will force the other streams
# to sync with the default stream (via _wait_for_previous_optim_step).
self._lazy_init()
assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
-assert self.training_state == TrainingState.IDLE
+self.assert_state(TrainingState.IDLE)
max_norm = float(max_norm)
norm_type = float(norm_type)
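A short usage sketch of the clipping path touched above, for orientation: clipping has to run on the root FSDP instance and on every rank, since it uses collectives to compute the global gradient norm. The method name `clip_grad_norm_` and the surrounding training-step names are assumptions for illustration, not part of this diff.

```python
def train_step_with_clipping(model, optimizer, batch, max_norm=1.0):
    # `model` is assumed to be the root FullyShardedDataParallel instance and
    # this function is assumed to run on every rank.
    loss = model(batch).sum()
    loss.backward()
    model.clip_grad_norm_(max_norm)  # root-only call, per the assert above
    optimizer.step()
    optimizer.zero_grad()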
@@ -495,13 +507,13 @@ class FullyShardedDataParallel(nn.Module):
old_flags = []
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
-old_flags.append((m, m.require_backward_grad_sync))
-m.require_backward_grad_sync = False
+old_flags.append((m, m._require_backward_grad_sync))
+m._require_backward_grad_sync = False
try:
yield
finally:
for m, old_flag in old_flags:
-m.require_backward_grad_sync = old_flag
+m._require_backward_grad_sync = old_flag
@contextlib.contextmanager
def summon_full_params(self, recurse: bool = True) -> Generator:
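The flag renamed above is the one toggled by what appears to be FSDP's `no_sync()` context manager (the enclosing method name is not visible in this hunk, so that is an assumption). A hedged sketch of the usual gradient-accumulation usage, assuming `model` is a root FSDP instance and `batches` is a list of input tensors:

```python
def accumulate_gradients(model, optimizer, batches):
    # Inside no_sync() the (now private) _require_backward_grad_sync flag is
    # False, so gradient reduction is skipped and local grads simply
    # accumulate across micro-batches.
    for batch in batches[:-1]:
        with model.no_sync():
            model(batch).sum().backward()
    # The final backward runs outside the context and synchronizes gradients.
    model(batches[-1]).sum().backward()
    optimizer.step()
    optimizer.zero_grad()
```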
@@ -546,12 +558,14 @@ class FullyShardedDataParallel(nn.Module):
def _reset_lazy_init(self) -> None:
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
self._is_root: Optional[bool] = None
self._queue_wait_for_post_backward_closure: Optional[Callable] = None
self._streams: Dict[str, torch.cuda.Stream] = {}
self._reducer: Optional[ReduceScatterBucketer] = None
def _lazy_init(self) -> None:
"""Initialization steps that should happen lazily, typically right
-before the first forward pass."""
+before the first forward pass.
+"""
# Initialize param attributes lazily, in case the param's dtype or
# device changes after __init__.
for p in self.params:
@@ -661,14 +675,26 @@ class FullyShardedDataParallel(nn.Module):
"""
if self._is_root is not None:
return
-# No FullyShardedDataParallel instance wraps this, else _is_root would be set to False
+# No FullyShardedDataParallel instance wraps this, else _is_root would be set to False.
self._is_root = True
-# As the root, we now set all children instances to False.
+assert self._queue_wait_for_post_backward_closure is None
+self._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
+# As the root, we now set all children instances to False and
+# give them a closure to try to queue a wait_for_post_backward.
self.children_share_process_group = True
for n, m in self.named_modules():
# `n != ""` excludes self.
if n != "" and isinstance(m, FullyShardedDataParallel):
assert m._is_root is None
m._is_root = False
# When root instance doesn't have params, allow children instances
# to queue the post_backward hook.
#
# TODO (Min): we should consider whether we can have an empty param at the root
# so that the root always has a callback on the backward graph.
if not self._has_params:
assert m._queue_wait_for_post_backward_closure is None
m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
if m.process_group != self.process_group:
self.children_share_process_group = False
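To make the hand-off above easier to follow, here is a toy (non-FSDP) sketch of the idea: the root gives each child a reference to its own bound method, so whichever instance actually fires a backward hook can queue the single root-level callback, even when the root owns no parameters. All names below are illustrative.

```python
class Node:
    """Toy stand-in for an FSDP instance; illustrative only."""

    def __init__(self, is_root):
        self.is_root = is_root
        self.queued = False
        self.queue_closure = None  # set by the root, mirroring the diff above

    def _queue_root_callback(self):
        assert self.is_root
        if not self.queued:  # queue at most one callback
            self.queued = True
            print("root callback queued")

    def wire_children(self, children):
        self.queue_closure = self._queue_root_callback
        for child in children:
            # mirrors: m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
            child.queue_closure = self._queue_root_callback

    def on_backward_hook(self):
        # Any instance whose hook fires tries to queue the root callback.
        if self.queue_closure is not None:
            self.queue_closure()


root, child = Node(is_root=True), Node(is_root=False)
root.wire_children([child])
child.on_backward_hook()  # queues the callback once
child.on_backward_hook()  # no-op: already queued
```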
@@ -779,7 +805,11 @@ class FullyShardedDataParallel(nn.Module):
"""Register backward hooks to reshard params and reduce-scatter grads."""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
-self._post_backward_callback_queued = False
+if self._is_root:
+    # This actually means that only the root instance has this field
+    # defined. Accidentally accessing this field raises an error on all
+    # other instances, giving us a nice bug checker.
+    self._post_backward_callback_queued = False
for p in self.params:
if p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
@@ -825,14 +855,15 @@ class FullyShardedDataParallel(nn.Module):
# pre_backward_hook.
self._free_fp16_param_shard([param])
-# Enqueue a callback at the end of the backward pass to ensure that all
-# post-backward work has finished. We only need one callback and it only
-# needs to be called from the outer-most (root) instance.
-if self._is_root and not self._post_backward_callback_queued:
-    self._post_backward_callback_queued = True
-    Variable._execution_engine.queue_callback(self._wait_for_post_backward)
+# (Try to) enqueue a callback at the end of the backward pass to ensure that all
+# post-backward work has finished. We only need one callback, and every FSDP
+# instance (root and children) attempts to queue it here, so that it is queued
+# no matter which instance(s) actually own params.
+assert self._queue_wait_for_post_backward_closure is not None or not self._is_root
+if self._queue_wait_for_post_backward_closure is not None:
+    self._queue_wait_for_post_backward_closure()
-if not self.require_backward_grad_sync:
+if not self._require_backward_grad_sync:
return
# Wait for all work in the current stream to finish, then start the
@@ -888,9 +919,22 @@ class FullyShardedDataParallel(nn.Module):
# Don't let this memory get reused until after the transfers.
reduced_grad.record_stream(torch.cuda.current_stream())
def _queue_wait_for_post_backward(self) -> None:
"""Try to queue a `wait_for_post_backward` callback.
Only called on the root instance, and only one callback is queued.
Children FSDP instances can invoke it via a closure (bound to the root)
when the root instance doesn't own any params.
"""
assert self._is_root
self.assert_state(TrainingState.BACKWARD)
if not self._post_backward_callback_queued:
self._post_backward_callback_queued = True
Variable._execution_engine.queue_callback(self._wait_for_post_backward)
@torch.no_grad()
def _wait_for_post_backward(self) -> None:
"""Wait for post-backward work to finish. Only called on root instance."""
"""Wait for post-backward work to finish. Only called on root instance.
"""
assert self._is_root
self.assert_state(TrainingState.BACKWARD)
# Flush any unreduced buckets in the post_backward stream.
......
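For readers unfamiliar with the mechanism, a small demonstration of `Variable._execution_engine.queue_callback`, the private PyTorch hook the code above relies on: the callback must be registered while the autograd engine is running (here, from a gradient hook) and fires once the whole backward pass has finished. The toy tensors below are assumptions for illustration.

```python
import torch
from torch.autograd import Variable


def _after_backward():
    print("all post-backward work can be flushed now")


x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()


def _hook(grad):
    # queue_callback must be called while the autograd engine is running,
    # e.g. from inside a gradient hook, as FSDP does from its backward hooks.
    Variable._execution_engine.queue_callback(_after_backward)
    return grad


x.register_hook(_hook)
y.backward()  # the queued callback runs once the backward pass finishes
```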
@@ -15,7 +15,7 @@ from parameterized import parameterized
import torch
from torch import nn
-from fairscale.nn.data_parallel import FullyShardedDataParallel
+from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
from fairscale.utils.testing import (
DeviceAndTypeCheckModule,
@@ -65,8 +65,8 @@ class DistributedTest(unittest.TestCase):
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
optim.step()
-if hasattr(model, "assert_idle"):
-    model.assert_idle()
+if isinstance(model, FullyShardedDataParallel):
+    model.assert_state(TrainingState.IDLE)
return loss.detach()
@staticmethod
@@ -182,6 +182,12 @@ class TestComparisonToPyTorchDDP(DistributedTest):
PyTorch DDP vs. FullyShardedDataParallel.
"""
def test_nested_all_wrapped_model(self):
config = {"mixed_precision": True}
model_fn = functools.partial(NestedWrappedModule, wrap_everything=True)
test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_parameterized(self, config):
# Test every combination of these options:
@@ -730,7 +736,7 @@ class TransformerWithSharedParams(nn.Module):
class NestedWrappedModule(nn.Module):
-def __init__(self, group, wrapper_config):
+def __init__(self, group, wrapper_config, wrap_everything=False):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
@@ -749,6 +755,15 @@ class NestedWrappedModule(nn.Module):
nn.Linear(4, 8),
)
# Wrapping all modules triggers a corner case where the root FSDP doesn't have any params.
if wrap_everything:
self.module = nn.Sequential(
_maybe_wrap(nn.Linear(8, 4)),
_maybe_wrap(nn.Linear(4, 16)),
_maybe_wrap(nn.Linear(16, 4)),
_maybe_wrap(nn.Linear(4, 8)),
)
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
return (torch.rand(4, 8, device=device),)
......
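The new test code calls a `_maybe_wrap` helper that is defined elsewhere in this test file and not shown in the hunk. A hedged guess at its shape, assuming it wraps a layer in FSDP only when a wrapper config was provided; the real helper closes over `group` and `wrapper_config` inside `__init__` and may differ in detail.

```python
from fairscale.nn.data_parallel import FullyShardedDataParallel


def _maybe_wrap(layer, group=None, wrapper_config=None):
    # Hypothetical reconstruction of the helper used above; the actual helper
    # lives in the elided part of NestedWrappedModule.__init__.
    if wrapper_config is not None:
        return FullyShardedDataParallel(layer, group, **wrapper_config)
    return layer
```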