Unverified Commit ba7df621 authored by Yanli Zhao, committed by GitHub

[FSDP] Move final backward callback queueing to pre-backward hook of root instance (#753)

Move the final backward callback to the pre-backward hook of the root FSDP instance

Summary:

Move the final backward callback to the pre-backward hook of the root FSDP instance,
so that it is always attached to the outermost backward call and fires
after all backward calls have completed.

Also added flags to check that the final backward callback has fired whenever
it is required.

If the root FSDP instance is checkpointed and called multiple times in forward,
a checkpoint counter is used to make sure the final backward callback is also queued
inside the last inner backward call.
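
For reference, the mechanism can be sketched outside of FSDP with plain PyTorch autograd hooks. This is a minimal illustration only (the `MiniRoot` class and its fields are made up for this example and are not FSDP code): a gradient hook registered on the root's output tensor plays the role of the pre-backward hook, and it queues a final callback through `Variable._execution_engine.queue_callback`, which the autograd engine runs once the whole backward pass has finished.

```python
import torch
from torch.autograd import Variable


class MiniRoot(torch.nn.Module):
    """Toy stand-in for a root FSDP instance (illustrative only)."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self._post_backward_callback_queued = False

    def _final_backward_callback(self) -> None:
        # Runs once, after all backward work for this iteration has completed.
        print("final backward callback fired")
        self._post_backward_callback_queued = False  # ready for the next iteration

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.linear(x)
        if torch.is_grad_enabled() and out.requires_grad:
            self._post_backward_callback_queued = False

            def _pre_backward_hook(*unused: torch.Tensor) -> None:
                # Fires when autograd first reaches this output, i.e. at the
                # start of the outermost backward; queue the callback only once.
                if not self._post_backward_callback_queued:
                    self._post_backward_callback_queued = True
                    Variable._execution_engine.queue_callback(self._final_backward_callback)

            out.register_hook(_pre_backward_hook)
        return out


if __name__ == "__main__":
    model = MiniRoot()
    model(torch.randn(2, 4)).sum().backward()  # callback prints after backward completes
```

In FSDP itself, the hook is registered on the root's outputs in `_register_pre_backward_hooks`, and `_wait_for_post_backward` plays the role of the final callback.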

Test Plan: unit tests

* reformat

* nits and unit tests

* address some comments

* replace m with self

* reformat

* nits

* remove the fired flag

* assert state on root only

* comments

* comments

parent 61ece000
@@ -885,7 +885,6 @@ class FullyShardedDataParallel(nn.Module):
def _reset_lazy_init(self) -> None:
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
self._is_root: Optional[bool] = None
self._queue_wait_for_post_backward_closure: Optional[Callable] = None
self._streams: Dict[str, torch.cuda.Stream] = {}
self._reducer: Optional[ReduceScatterBucketer] = None
for p in self.params:
@@ -1002,8 +1001,18 @@ class FullyShardedDataParallel(nn.Module):
return
# No FullyShardedDataParallel instance wraps this, else _is_root would be set to False.
self._is_root = True
assert self._queue_wait_for_post_backward_closure is None
self._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
# If the final backward callback has never been queued, the state should be IDLE.
# If the final backward callback has been queued, the callback should have finished
# and reset the state back to IDLE.
# This should be asserted at the beginning of the forward pass, and only on the root instance.
# For children instances, if they are checkpointed, the state will not be reset to
# IDLE after each inner forward/backward.
self.assert_state(TrainingState.IDLE)
# Check if the root instance is being checkpointed. It doesn't make sense to
# checkpoint the root instance since it won't save GPU memory.
assert (
getattr(self, "_checkpoint_fwd_counter", 0) == 0
), "Is the root FSDP module wrapping an activation checkpointed module? If so, please remove that."
# As the root, we now set all children instances to False and
# give them a closure to try to queue a wait_for_post_backward.
self.children_share_process_group = True
@@ -1015,14 +1024,6 @@ class FullyShardedDataParallel(nn.Module):
assert m._is_root is None or not m._is_root
if m._is_root is None:
m._is_root = False
# When root instance doesn't have params, allow children instances
# to queue the post_backward hook.
#
# TODO (Min): we should think if we can have a empty param at the root
# so that root always have a callback on the backward graph.
if not self._has_params:
assert m._queue_wait_for_post_backward_closure is None
m._queue_wait_for_post_backward_closure = self._queue_wait_for_post_backward
if m.process_group != self.process_group:
self.children_share_process_group = False
@@ -1139,7 +1140,20 @@ class FullyShardedDataParallel(nn.Module):
if not torch.is_grad_enabled():
return outputs # don't register hooks if grad isn't enabled
if self._is_root:
# This means that only the root instance has
# _post_backward_callback_queued defined. Accidentally accessing this field
# on any other instance will raise an error, giving us a nice bug checker.
self._post_backward_callback_queued = False
def _pre_backward_hook(*unused: Any) -> None:
# Try to queue the final backward callback only once for the root, so
# that the final backward callback is attached to the outermost
# backward graph task and is called after all backward
# calls have completed.
if self._is_root:
self._queue_wait_for_post_backward()
if self._pre_backward_hook_has_run:
return # only run once (from multiple outputs or multiple forward passes)
self._pre_backward_hook_has_run = True
@@ -1204,11 +1218,6 @@ class FullyShardedDataParallel(nn.Module):
"""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
if self._is_root:
# This actually means that only root instance has this field
# defined. Accidentally accessing this field will assert on all
# other instances, giving us a nice bug checker.
self._post_backward_callback_queued = False
for p in self.params:
if p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
@@ -1281,14 +1290,6 @@ class FullyShardedDataParallel(nn.Module):
# Switch to FP32 shard after backward.
self._use_fp32_param_shard([param])
# (try to) Enqueue a callback at the end of the backward pass to ensure that all
# post-backward work has finished. We only need one callback and all instances
# of FSDP (root and children) make this attempt here to queue to ensure it is queued
# no matter which instance(s) has(have) params.
assert self._queue_wait_for_post_backward_closure is not None or not self._is_root
if self._queue_wait_for_post_backward_closure is not None:
self._queue_wait_for_post_backward_closure()
if not self._require_backward_grad_sync:
return
@@ -1354,13 +1355,12 @@ class FullyShardedDataParallel(nn.Module):
def _queue_wait_for_post_backward(self) -> None:
"""Try to queue a `wait_for_post_backward` callback.
Only called on root and only queue one callback. But can be called by
children FSDPs via a closure in case the root instance doesn't own any
params.
Only called on the root instance, and only one callback is queued, at the
beginning of the outermost backward pass.
"""
assert self._is_root
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
if not self._post_backward_callback_queued:
self.assert_state([TrainingState.IDLE])
self._post_backward_callback_queued = True
Variable._execution_engine.queue_callback(self._wait_for_post_backward)
@@ -1409,12 +1409,19 @@ class FullyShardedDataParallel(nn.Module):
else:
m.assert_state(TrainingState.BACKWARD_PRE)
else:
# Unlikely case. When `m` and its children has no params or has params but
# none with `requires_grad==True`, then m's pre-backward and post-backward
# hooks aren't called by autograd. Therefore, it is in IDLE state.
m.assert_state(TrainingState.IDLE)
# When `m` and its children have no params, or have params but
# none with `requires_grad==True`, there are two cases:
# 1. The output tensors have `requires_grad==True`. In this case, the
# pre-backward hook is still registered, so `m` is in the BACKWARD_PRE state.
# 2. The output tensors have `requires_grad==False`. In this case, the
# pre-backward hook is not registered, so `m` is in the IDLE state.
m.assert_state([TrainingState.BACKWARD_PRE, TrainingState.IDLE])
m.training_state = TrainingState.IDLE
if m._is_root:
# reset this flag for cases like "one forward pass + multiple backward passes"
self._post_backward_callback_queued = False
@torch.no_grad()
def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
"""
......
@@ -278,6 +278,12 @@ class TestComparisonToPyTorchDDP(DistributedTest):
test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_all_wrapped_model_checkpoint(self, config):
model_fn = functools.partial(NestedWrappedModule, wrap_everything=True, checkpoint=True)
test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_parameterized(self, config):
# Test every combination of these options:
@@ -571,7 +577,7 @@ class TransformerWithSharedParams(nn.Module):
class NestedWrappedModule(nn.Module):
def __init__(self, group, wrapper_config, wrap_everything=False):
def __init__(self, group, wrapper_config, wrap_everything=False, checkpoint=False):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
@@ -591,7 +597,18 @@ class NestedWrappedModule(nn.Module):
)
# Wrapping all modules triggers a corner case where the root FSDP doesn't have any params.
# Test it with checkpoint_wrapper as well to validate that the final backward callback
# is queued correctly when the root FSDP does not have any params and every layer is
# wrapped as FSDP(checkpoint(module)).
if wrap_everything:
if checkpoint:
self.module = nn.Sequential(
_maybe_wrap(checkpoint_wrapper(nn.Linear(8, 4))),
_maybe_wrap(checkpoint_wrapper(nn.Linear(4, 16))),
_maybe_wrap(checkpoint_wrapper(nn.Linear(16, 4))),
_maybe_wrap(checkpoint_wrapper(nn.Linear(4, 8))),
)
else:
self.module = nn.Sequential(
_maybe_wrap(nn.Linear(8, 4)),
_maybe_wrap(nn.Linear(4, 16)),
......