"vscode:/vscode.git/clone" did not exist on "4d247b084164e8c6873dd8c7d71e083d320c3197"
Unverified commit b75a5e26, authored by Min Xu, committed by GitHub

[fix] FSDP: fix the corner case where all params are in the children (#441)

* [fix] FSDP corner case of all params in the children

* lint

* fix

* tradeoff

* fix doc build

* review comments
parent bd04f21f
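
The corner case fixed here is the one exercised by the new `wrap_everything=True` test further down: every parameter lives inside a child FSDP instance, so the outer (root) FSDP owns no parameters of its own and, before this change, its post-backward hook never ran and the `_wait_for_post_backward` callback was never queued. A minimal sketch of that structure, assuming `torch.distributed` is already initialized and FSDP's device requirements are met (module sizes mirror the test and are otherwise arbitrary):

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    def build_all_wrapped_model() -> FSDP:
        # Every child owns its params inside a nested FSDP wrapper, so the
        # outer (root) FSDP below holds no parameters directly.
        return FSDP(
            nn.Sequential(
                FSDP(nn.Linear(8, 4)),
                FSDP(nn.Linear(4, 16)),
                FSDP(nn.Linear(16, 4)),
                FSDP(nn.Linear(4, 8)),
            )
        )

    model = build_all_wrapped_model()  # root FSDP with zero directly-owned params
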
@@ -3,5 +3,5 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
-from .fully_sharded_data_parallel import FullyShardedDataParallel
+from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState
 from .sharded_ddp import ShardedDataParallel
@@ -8,7 +8,7 @@ import copy
 from enum import Enum, auto
 import functools
 from math import inf
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
 
 import torch
 from torch.autograd import Variable
@@ -94,6 +94,13 @@ class FullyShardedDataParallel(nn.Module):
         since FSDP will shard parameters in-place and this will break any
         previously initialized optimizers.
 
+    .. warning::
+
+        If you wrap every parameter inside a nested FSDP and leave the outer
+        FSDP without any parameters of its own, activation checkpointing may
+        trigger an assert in the backward pass. The solution is to leave some
+        parameters to the outer FSDP.
+
     Args:
         module (nn.Module):
             module to checkpoint
@@ -172,7 +179,8 @@ class FullyShardedDataParallel(nn.Module):
         # shard any leftover parameters.
         params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))
 
-        if self.flatten_parameters and len(params) > 0:
+        self._has_params = len(params) > 0
+        if self.flatten_parameters and self._has_params:
             self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
             del module  # free original module in case it helps garbage collection
             self.params = [self._fsdp_wrapped_module.flat_param]
@@ -191,7 +199,7 @@ class FullyShardedDataParallel(nn.Module):
 
         # Flag to indicate if we require gradient reduction in the backward
        # pass. This will be False when inside the no_sync context manager.
-        self.require_backward_grad_sync: bool = True
+        self._require_backward_grad_sync: bool = True
 
         # Enum to indicate if we're in the forward/backward pass, idle, etc.
         self.training_state = TrainingState.IDLE
@@ -251,9 +259,13 @@ class FullyShardedDataParallel(nn.Module):
         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
+        # We don't call torch.cuda.synchronize() here, since clipping can be
+        # inside the train loop and we probably don't want to force a GPU-CPU sync.
+        # _lazy_init should be sufficient, since it will force the other streams
+        # to sync with the default stream (via _wait_for_previous_optim_step).
         self._lazy_init()
         assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
-        assert self.training_state == TrainingState.IDLE
+        self.assert_state(TrainingState.IDLE)
         max_norm = float(max_norm)
         norm_type = float(norm_type)
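
The method touched in the hunk above is FSDP's own gradient clipping entry point; per the assert and the docstring warning, it must be called on the root instance and on every rank, since synchronization primitives are used on the sharded gradients. A hedged usage sketch, reusing the hypothetical `model` from the earlier sketch (the `clip_grad_norm_` name and the SGD optimizer created after wrapping are assumptions consistent with the test changes below):

    import torch

    optim = torch.optim.SGD(model.parameters(), lr=0.01)  # created after wrapping, per the docstring warning
    model(torch.rand(4, 8)).sum().backward()
    model.clip_grad_norm_(max_norm=1.0)  # FSDP's method, instead of torch.nn.utils.clip_grad_norm_
    optim.step()
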
@@ -495,13 +507,13 @@ class FullyShardedDataParallel(nn.Module):
         old_flags = []
         for m in self.modules():  # includes self
             if isinstance(m, FullyShardedDataParallel):
-                old_flags.append((m, m.require_backward_grad_sync))
-                m.require_backward_grad_sync = False
+                old_flags.append((m, m._require_backward_grad_sync))
+                m._require_backward_grad_sync = False
         try:
             yield
         finally:
             for m, old_flag in old_flags:
-                m.require_backward_grad_sync = old_flag
+                m._require_backward_grad_sync = old_flag
 
     @contextlib.contextmanager
     def summon_full_params(self, recurse: bool = True) -> Generator:
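
The renamed `_require_backward_grad_sync` flag is the one toggled by the `no_sync()` context manager shown in the hunk above. A hedged sketch of the intended gradient-accumulation usage, again reusing the hypothetical `model` and `optim` from the earlier sketches:

    import torch

    micro_batches = [torch.rand(4, 8) for _ in range(4)]
    with model.no_sync():
        # _require_backward_grad_sync is False here, so each backward()
        # accumulates gradients locally instead of reducing them.
        for micro_batch in micro_batches[:-1]:
            model(micro_batch).sum().backward()
    model(micro_batches[-1]).sum().backward()  # gradients are synchronized on this step
    optim.step()
    optim.zero_grad()
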
@@ -546,12 +558,14 @@ class FullyShardedDataParallel(nn.Module):
     def _reset_lazy_init(self) -> None:
         """Reset instance so :func:`_lazy_init` will run on the next forward."""
         self._is_root: Optional[bool] = None
+        self._queue_wait_for_post_backward_closure: Optional[Callable] = None
         self._streams: Dict[str, torch.cuda.Stream] = {}
         self._reducer: Optional[ReduceScatterBucketer] = None
 
     def _lazy_init(self) -> None:
         """Initialization steps that should happen lazily, typically right
-        before the first forward pass."""
+        before the first forward pass.
+        """
         # Initialize param attributes lazily, in case the param's dtype or
         # device changes after __init__.
         for p in self.params:
@@ -661,14 +675,26 @@ class FullyShardedDataParallel(nn.Module):
         """
         if self._is_root is not None:
             return
-        # No FullyShardedDataParallel instance wraps this, else _is_root would be set to False
+        # No FullyShardedDataParallel instance wraps this, else _is_root would be set to False.
         self._is_root = True
-        # As the root, we now set all children instances to False.
+        assert self._queue_wait_for_post_backward_closure is None
+        self._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
+        # As the root, we now set all children instances to False and
+        # give them a closure to try to queue a wait_for_post_backward.
         self.children_share_process_group = True
         for n, m in self.named_modules():
+            # `n != ""` excludes self.
             if n != "" and isinstance(m, FullyShardedDataParallel):
                 assert m._is_root is None
                 m._is_root = False
+                # When the root instance doesn't have params, allow children instances
+                # to queue the post_backward hook.
+                #
+                # TODO (Min): we should consider whether we can have an empty param at the
+                # root so that the root always has a callback on the backward graph.
+                if not self._has_params:
+                    assert m._queue_wait_for_post_backward_closure is None
+                    m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
                 if m.process_group != self.process_group:
                     self.children_share_process_group = False
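
A stripped-down, FSDP-free sketch of the closure hand-off this hunk introduces (all names below are illustrative, not fairscale API): the root keeps the once-only flag and hands every child a reference to its bound method, so the single root-level callback gets queued no matter which instance's backward hook fires first.

    class _Node:
        """Toy stand-in for an FSDP instance; models only the callback queueing."""

        def __init__(self) -> None:
            self.is_root = False
            self.queue_closure = None      # set by the root during "lazy init"
            self._callback_queued = False  # only meaningful on the root

        def _queue_once(self) -> None:
            assert self.is_root
            if not self._callback_queued:
                self._callback_queued = True
                print("final callback queued exactly once")

        def set_is_root(self, children, has_params: bool) -> None:
            self.is_root = True
            self.queue_closure = self._queue_once
            for child in children:
                if not has_params:
                    # Root owns no params, so let the children trigger the queueing.
                    child.queue_closure = self._queue_once

    root, kids = _Node(), [_Node(), _Node()]
    root.set_is_root(kids, has_params=False)
    for node in kids:
        node.queue_closure()  # first call queues the callback; later calls are no-ops
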
@@ -779,6 +805,10 @@ class FullyShardedDataParallel(nn.Module):
         """Register backward hooks to reshard params and reduce-scatter grads."""
         if not torch.is_grad_enabled():
             return  # don't register grad hooks if grad isn't enabled
-        self._post_backward_callback_queued = False
+        if self._is_root:
+            # This actually means that only root instance has this field
+            # defined. Accidentally accessing this field will assert on all
+            # other instances, giving us a nice bug checker.
+            self._post_backward_callback_queued = False
         for p in self.params:
             if p.requires_grad:
@@ -825,14 +855,15 @@ class FullyShardedDataParallel(nn.Module):
             # pre_backward_hook.
             self._free_fp16_param_shard([param])
 
-        # Enqueue a callback at the end of the backward pass to ensure that all
-        # post-backward work has finished. We only need one callback and it only
-        # needs to be called from the outer-most (root) instance.
-        if self._is_root and not self._post_backward_callback_queued:
-            self._post_backward_callback_queued = True
-            Variable._execution_engine.queue_callback(self._wait_for_post_backward)
+        # (Try to) enqueue a callback at the end of the backward pass to ensure that
+        # all post-backward work has finished. We only need one callback, but every
+        # FSDP instance (root and children) makes the attempt here, so that it gets
+        # queued no matter which instance(s) actually own params.
+        assert self._queue_wait_for_post_backward_closure is not None or not self._is_root
+        if self._queue_wait_for_post_backward_closure is not None:
+            self._queue_wait_for_post_backward_closure()
 
-        if not self.require_backward_grad_sync:
+        if not self._require_backward_grad_sync:
             return
 
         # Wait for all work in the current stream to finish, then start the
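
Both the old and the new code path rely on `Variable._execution_engine.queue_callback`, an internal autograd hook that runs a function once the current backward pass has fully finished; it must be invoked while backward is executing, for example from a gradient hook. A small self-contained illustration (unrelated to FSDP; the hook and callback names are ours):

    import torch
    from torch.autograd import Variable

    def on_backward_done() -> None:
        print("all post-backward work can be flushed now")

    def grad_hook(grad: torch.Tensor) -> torch.Tensor:
        # Only valid while the autograd engine is running a backward pass;
        # the callback fires after the whole graph has been processed.
        Variable._execution_engine.queue_callback(on_backward_done)
        return grad

    x = torch.randn(4, requires_grad=True)
    x.register_hook(grad_hook)
    (x * 2).sum().backward()  # prints once, after the backward pass completes
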
@@ -888,9 +919,22 @@ class FullyShardedDataParallel(nn.Module):
             # Don't let this memory get reused until after the transfers.
             reduced_grad.record_stream(torch.cuda.current_stream())
 
+    def _queue_wait_for_post_backward(self) -> None:
+        """Try to queue a `wait_for_post_backward` callback.
+        Only called on the root instance, and only one callback is queued;
+        children FSDP instances may trigger it via a closure in case the
+        root instance doesn't own any params.
+        """
+        assert self._is_root
+        self.assert_state(TrainingState.BACKWARD)
+        if not self._post_backward_callback_queued:
+            self._post_backward_callback_queued = True
+            Variable._execution_engine.queue_callback(self._wait_for_post_backward)
+
     @torch.no_grad()
     def _wait_for_post_backward(self) -> None:
-        """Wait for post-backward work to finish. Only called on root instance."""
+        """Wait for post-backward work to finish. Only called on root instance.
+        """
         assert self._is_root
         self.assert_state(TrainingState.BACKWARD)
         # Flush any unreduced buckets in the post_backward stream.
...
@@ -15,7 +15,7 @@ from parameterized import parameterized
 import torch
 from torch import nn
 
-from fairscale.nn.data_parallel import FullyShardedDataParallel
+from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
 from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
 from fairscale.utils.testing import (
     DeviceAndTypeCheckModule,
@@ -65,8 +65,8 @@ class DistributedTest(unittest.TestCase):
                 else:
                     torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
             optim.step()
-        if hasattr(model, "assert_idle"):
-            model.assert_idle()
+        if isinstance(model, FullyShardedDataParallel):
+            model.assert_state(TrainingState.IDLE)
         return loss.detach()
 
     @staticmethod
@@ -182,6 +182,12 @@ class TestComparisonToPyTorchDDP(DistributedTest):
     PyTorch DDP vs. FullyShardedDataParallel.
     """
 
+    def test_nested_all_wrapped_model(self):
+        config = {"mixed_precision": True}
+        model_fn = functools.partial(NestedWrappedModule, wrap_everything=True)
+        test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
+        spawn_and_init(test_fn)
+
     @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
     def test_transformer_parameterized(self, config):
         # Test every combination of these options:
@@ -730,7 +736,7 @@ class TransformerWithSharedParams(nn.Module):
 
 
 class NestedWrappedModule(nn.Module):
-    def __init__(self, group, wrapper_config):
+    def __init__(self, group, wrapper_config, wrap_everything=False):
         super().__init__()
         self.rank = group.rank()
         self.world_size = group.size()
@@ -749,6 +755,15 @@ class NestedWrappedModule(nn.Module):
                 nn.Linear(4, 8),
             )
 
+        # Wrapping all modules triggers a corner case where the root FSDP doesn't have any params.
+        if wrap_everything:
+            self.module = nn.Sequential(
+                _maybe_wrap(nn.Linear(8, 4)),
+                _maybe_wrap(nn.Linear(4, 16)),
+                _maybe_wrap(nn.Linear(16, 4)),
+                _maybe_wrap(nn.Linear(4, 8)),
+            )
+
     def get_input(self, device):
         torch.manual_seed(1 + self.rank)  # keep everything deterministic
         return (torch.rand(4, 8, device=device),)
...