Unverified Commit ba7df621 authored by Yanli Zhao, committed by GitHub

[FSDP] Move final backward callback queueing to pre-backward hook of root instance (#753)

Move final backward callback to pre-backward hook of root FSDP instance

Summary:

Move the final backward callback queueing to the pre-backward hook of the root
FSDP instance, so that the callback is always attached to the outermost
backward graph task and fires only after all backward calls have completed.
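
As a rough illustration of the mechanism (a minimal standalone sketch, not the FSDP
implementation; the RootModule class and its attributes are made up for the example),
a gradient hook registered on the root module's forward output plays the role of the
pre-backward hook: it runs early in the outermost backward pass and queues a callback
with autograd's execution engine, the same Variable._execution_engine.queue_callback
call used in the diff below, and that callback only runs once the entire backward
graph has been processed.

import torch
from torch.autograd import Variable


class RootModule(torch.nn.Module):
    # Illustrative stand-in for the root FSDP instance.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self._final_callback_queued = False

    def _final_backward_callback(self):
        # Fires once, after all backward work for this iteration has finished.
        print("final backward callback fired")
        self._final_callback_queued = False

    def forward(self, x):
        self._final_callback_queued = False
        out = self.linear(x)

        def _pre_backward_hook(grad):
            # Queue the end-of-backward callback only once per backward pass.
            if not self._final_callback_queued:
                self._final_callback_queued = True
                Variable._execution_engine.queue_callback(self._final_backward_callback)
            return grad

        if torch.is_grad_enabled() and out.requires_grad:
            out.register_hook(_pre_backward_hook)
        return out


module = RootModule()
module(torch.randn(2, 4)).sum().backward()  # callback prints once, after backward completes

Because the hook lives on the root's output, the callback is queued inside whichever
backward call is outermost, which is the property this change relies on.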

Also added flags to check that the final backward callback is actually fired
whenever it is required.

If the root FSDP instance is checkpointed and called multiple times in forward,
a checkpoint counter is used to make sure the final backward callback is also
queued inside the last inner backward call.
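
For context on why checkpointing needs special handling (see the
_checkpoint_fwd_counter assertion in the diff below): an activation-checkpointed
module re-runs its forward during the backward pass, so per-iteration bookkeeping
done in forward can be hit again inside that inner recomputation. The sketch below
uses plain torch.utils.checkpoint rather than FSDP to show the re-entry; the Inner
module and the calls counter are only for illustration.

import torch
from torch.utils.checkpoint import checkpoint

calls = {"forward": 0}


class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        calls["forward"] += 1  # counts both the real forward and the recomputation
        return self.linear(x)


inner = Inner()
x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(inner, x)  # forward runs, activations are discarded
out.sum().backward()        # forward runs again here to recompute activations
print(calls["forward"])     # 2: one outer forward plus one re-run inside backward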

Test Plan: unit tests


* reformat
* nits and unit tests
* address some comments
* replace m with self
* reformat
* nits
* remove the fired flag
* assert state on root only
* comments
* comments
parent 61ece000
@@ -885,7 +885,6 @@ class FullyShardedDataParallel(nn.Module):
     def _reset_lazy_init(self) -> None:
         """Reset instance so :func:`_lazy_init` will run on the next forward."""
         self._is_root: Optional[bool] = None
-        self._queue_wait_for_post_backward_closure: Optional[Callable] = None
         self._streams: Dict[str, torch.cuda.Stream] = {}
         self._reducer: Optional[ReduceScatterBucketer] = None
         for p in self.params:
@@ -1002,8 +1001,18 @@ class FullyShardedDataParallel(nn.Module):
             return
         # No FullyShardedDataParallel instance wraps this, else _is_root would be set to False.
         self._is_root = True
-        assert self._queue_wait_for_post_backward_closure is None
-        self._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
+        # If final backward callback is never been queued, state should be IDLE.
+        # If final backward callback is queued, the callback should be finished
+        # and the state was reset to be IDLE.
+        # This should be asserted at the beginning of forward pass in the root instance only.
+        # For children instances, if they are checkpointed, state will not be reset to
+        # IDLE after each inner forward/backward.
+        self.assert_state(TrainingState.IDLE)
+        # Check if the root instance is being checkpointed. It doesn't make sense to
+        # checkpoint the root instance since it won't save GPU memory.
+        assert (
+            getattr(self, "_checkpoint_fwd_counter", 0) == 0
+        ), "Is the root FSDP module wrapping an activation checkpointed module? If so, please remove that."
         # As the root, we now set all children instances to False and
         # give them a closure to try to queue a wait_for_post_backward.
         self.children_share_process_group = True
@@ -1015,14 +1024,6 @@ class FullyShardedDataParallel(nn.Module):
                 assert m._is_root is None or not m._is_root
                 if m._is_root is None:
                     m._is_root = False
-                    # When root instance doesn't have params, allow children instances
-                    # to queue the post_backward hook.
-                    #
-                    # TODO (Min): we should think if we can have a empty param at the root
-                    # so that root always have a callback on the backward graph.
-                    if not self._has_params:
-                        assert m._queue_wait_for_post_backward_closure is None
-                        m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
                 if m.process_group != self.process_group:
                     self.children_share_process_group = False
@@ -1139,7 +1140,20 @@ class FullyShardedDataParallel(nn.Module):
         if not torch.is_grad_enabled():
             return outputs  # don't register hooks if grad isn't enabled
 
+        if self._is_root:
+            # This actually means that only root instance has
+            # _post_backward_callback_queued defined. Accidentally accessing this field
+            # will assert on all other instances, giving us a nice bug checker.
+            self._post_backward_callback_queued = False
+
         def _pre_backward_hook(*unused: Any) -> None:
+            # try to queue final backward callback only once for root, so
+            # that final backward callback is attached to the outer most
+            # backward graph task and called after all the backward
+            # calls are completed.
+            if self._is_root:
+                self._queue_wait_for_post_backward()
+
             if self._pre_backward_hook_has_run:
                 return  # only run once (from multiple outputs or multiple forward passes)
             self._pre_backward_hook_has_run = True
@@ -1204,11 +1218,6 @@ class FullyShardedDataParallel(nn.Module):
         """
         if not torch.is_grad_enabled():
             return  # don't register grad hooks if grad isn't enabled
-        if self._is_root:
-            # This actually means that only root instance has this field
-            # defined. Accidentally accessing this field will assert on all
-            # other instances, giving us a nice bug checker.
-            self._post_backward_callback_queued = False
         for p in self.params:
             if p.requires_grad:
                 if hasattr(p, "_shard_bwd_hook"):
@@ -1281,14 +1290,6 @@ class FullyShardedDataParallel(nn.Module):
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
 
-        # (try to) Enqueue a callback at the end of the backward pass to ensure that all
-        # post-backward work has finished. We only need one callback and all instances
-        # of FSDP (root and children) make this attempt here to queue to ensure it is queued
-        # no matter which instance(s) has(have) params.
-        assert self._queue_wait_for_post_backward_closure is not None or not self._is_root
-        if self._queue_wait_for_post_backward_closure is not None:
-            self._queue_wait_for_post_backward_closure()
-
         if not self._require_backward_grad_sync:
             return
@@ -1354,13 +1355,12 @@ class FullyShardedDataParallel(nn.Module):
     def _queue_wait_for_post_backward(self) -> None:
         """Try to queue a `wait_for_post_backward` callback.
 
-        Only called on root and only queue one callback. But can be called by
-        children FSDPs via a closure in case the root instance doesn't own any
-        params.
+        Only called on root and only queue one callback at the beginning of
+        outer most backward.
         """
         assert self._is_root
-        self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
         if not self._post_backward_callback_queued:
+            self.assert_state([TrainingState.IDLE])
             self._post_backward_callback_queued = True
             Variable._execution_engine.queue_callback(self._wait_for_post_backward)
@@ -1409,12 +1409,19 @@ class FullyShardedDataParallel(nn.Module):
                     else:
                         m.assert_state(TrainingState.BACKWARD_PRE)
                 else:
-                    # Unlikely case. When `m` and its children has no params or has params but
-                    # none with `requires_grad==True`, then m's pre-backward and post-backward
-                    # hooks aren't called by autograd. Therefore, it is in IDLE state.
-                    m.assert_state(TrainingState.IDLE)
+                    # When `m` and its children has no params or has params but
+                    # none with `requires_grad==True`, there are two cases:
+                    # 1. output tensors are `requires_grad==True`. In this case,
+                    #    pre-backward hook is still registered, so it is in BACKWARD_PRE state.
+                    # 2. output tensors are `requires_grad==False`. In this case,
+                    #    pre-backward hook is not registered, so it is in IDLE state.
+                    m.assert_state([TrainingState.BACKWARD_PRE, TrainingState.IDLE])
                 m.training_state = TrainingState.IDLE
+                if m._is_root:
+                    # reset this flag for cases like "one forward pass + multiple backward passes"
+                    self._post_backward_callback_queued = False
 
     @torch.no_grad()
     def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
         """
@@ -278,6 +278,12 @@ class TestComparisonToPyTorchDDP(DistributedTest):
         test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
         spawn_and_init(test_fn)
 
+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_nested_all_wrapped_model_checkpoint(self, config):
+        model_fn = functools.partial(NestedWrappedModule, wrap_everything=True, checkpoint=True)
+        test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
+        spawn_and_init(test_fn)
+
     @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
     def test_transformer_parameterized(self, config):
         # Test every combination of these options:
@@ -571,7 +577,7 @@ class TransformerWithSharedParams(nn.Module):
 
 
 class NestedWrappedModule(nn.Module):
-    def __init__(self, group, wrapper_config, wrap_everything=False):
+    def __init__(self, group, wrapper_config, wrap_everything=False, checkpoint=False):
         super().__init__()
         self.rank = group.rank()
         self.world_size = group.size()
@@ -591,13 +597,24 @@ class NestedWrappedModule(nn.Module):
         )
 
         # Wrap all modules triggers a corner case where root FSDP doesn't have any params.
+        # Test it with checkpoint_wrapper as well to validate final backward callback
+        # is queued correctly when root FSDP does not have any params and every layer is
+        # wrapped as FSDP(checkpoint(module)).
         if wrap_everything:
-            self.module = nn.Sequential(
-                _maybe_wrap(nn.Linear(8, 4)),
-                _maybe_wrap(nn.Linear(4, 16)),
-                _maybe_wrap(nn.Linear(16, 4)),
-                _maybe_wrap(nn.Linear(4, 8)),
-            )
+            if checkpoint:
+                self.module = nn.Sequential(
+                    _maybe_wrap(checkpoint_wrapper(nn.Linear(8, 4))),
+                    _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 16))),
+                    _maybe_wrap(checkpoint_wrapper(nn.Linear(16, 4))),
+                    _maybe_wrap(checkpoint_wrapper(nn.Linear(4, 8))),
+                )
+            else:
+                self.module = nn.Sequential(
+                    _maybe_wrap(nn.Linear(8, 4)),
+                    _maybe_wrap(nn.Linear(4, 16)),
+                    _maybe_wrap(nn.Linear(16, 4)),
+                    _maybe_wrap(nn.Linear(4, 8)),
+                )
 
     def get_input(self, device):
         torch.manual_seed(1 + self.rank)  # keep everything deterministic