Unverified Commit 95d31d4d authored by Min Xu, committed by GitHub

[fix] [FSDP] making sure we use full params for multiple backwards within an iteration (#775)



* [bug] [FSDP] making sure we use full params for multiple backwards within an iteration

* changelog
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent c10447f9
@@ -11,10 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
be set in the post backward hook. Modified the assert to account for the fact that the root
FSDP module can have child modules with params that require grad while itself containing
params that don't require grad, which could fail the previous assert. [#761]
- FSDP: Fixed a bug where parameters' sharding state might be incorrect when multiple
backward passes are called within an iteration. [#775]
### Added
- FSDP: Added support for returning the original names of parameters when `named_parameters` is called on
the module. To retrieve the original names of the parameters along with the params, you need to
call `named_parameters` under the `summon_full_params` context when using flattened params or original
params. If you are using original params (i.e., flatten_params=False), calling `named_parameters` outside
of the `summon_full_params` context will still return the original param names along with the local shards. [#755]
......
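As a concrete illustration of the `named_parameters` entry above, here is a minimal, hedged sketch; the single-process group setup and the wrapped `Linear` module are hypothetical stand-ins, not part of the commit.

```python
# Minimal sketch: retrieving original param names under summon_full_params.
# The single-process "gloo" group and the tiny module are hypothetical.
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

dist.init_process_group("gloo", init_method="tcp://localhost:29500",
                        rank=0, world_size=1)

model = FSDP(torch.nn.Linear(8, 8), flatten_parameters=True)

# With flattened params, names outside this context refer to the flat buffer;
# inside it, named_parameters yields the original names and full params.
with model.summon_full_params():
    for name, param in model.named_parameters():
        print(name, tuple(param.shape))
```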
@@ -665,10 +665,10 @@ class FullyShardedDataParallel(nn.Module):
parameter as well as the parameter.
With FSDP, the `named_parameters` function implemented in `nn.Module` will not
be able to return the name and param when we use flattened parameters unless
we call this function under a `summon_full_params` context.
If you want the full param to be returned, you should call this function
under a `summon_full_params` context when using flattened or original params.
"""
named_param = super().named_parameters(*args, **kwargs)
@@ -1184,21 +1184,36 @@ class FullyShardedDataParallel(nn.Module):
if self._is_root:
self._queue_wait_for_post_backward()
# All-gather full parameters or switch to the full params.
#
# This needs to be done on every pre_backward hook, even within the same
# iteration (i.e. for checkpointed, multiple forward pass modules). This is
# because after the forward pass (i.e. in checkpoint inner graph), we always
# switch to fp32_shard in the ``forward`` function.
#
# We used to do this only after the ``self._pre_backward_hook_has_run``
# boolean guard below, which is incorrect. It worked in pytorch < 1.9 for
# some unknown reason, but pytorch 1.10 nightly exposed this bug.
#
# Note: both ``self._rebuild_full_params`` and ``self._use_full_params`` are
# idempotent, so when they are called unnecessarily they don't incur much
# overhead.
if self.reshard_after_forward:
self._rebuild_full_params()
else:
self._use_full_params()
# Only run ``self._prep_grads_for_backward`` once per iteration (i.e. in case
# of multiple outputs or multiple forward passes).
if self._pre_backward_hook_has_run:
return
self._pre_backward_hook_has_run = True
# Start of a backward pass for the first time in an iteration.
self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
self.training_state = TrainingState.BACKWARD_PRE
# Prepare p.grad so that it has the right shape, device, accumulated values, etc.
self._prep_grads_for_backward()
def _register_hook(t: torch.Tensor) -> torch.Tensor:
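To make the multiple-backward scenario concrete, here is a hedged sketch of a training step that fires the pre-backward hook twice in one iteration; the module, optimizer, and random inputs are hypothetical, and a process group is assumed to be initialized as in the sketch above.

```python
# Hypothetical trigger for the fixed bug: two forward/backward passes in one
# iteration (as with activation checkpointing or multiple losses).
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = FSDP(torch.nn.Linear(16, 16))  # assumes dist is already initialized
optim = torch.optim.SGD(model.parameters(), lr=0.01)

optim.zero_grad()
model(torch.randn(4, 16)).sum().backward()  # 1st backward: hook gathers full params
# The second forward switches p.data back to the local fp32 shard, so the
# second pre-backward hook must rebuild/use full params again -- the fix above.
model(torch.randn(4, 16)).sum().backward()
optim.step()
```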
@@ -1472,6 +1487,9 @@ class FullyShardedDataParallel(nn.Module):
"""
Gather all shards of params.
Note: this is idempotent if full params are already gathered. Callers
assume this idempotency, so please keep it that way.
Args:
force_full_precision (bool, Optional): by default params will be gathered
in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
@@ -1567,6 +1585,9 @@ class FullyShardedDataParallel(nn.Module):
Switch p.data pointers to use the full params.
Note: this assumes full params are already gathered.
Note: this might be called while full params are already in use, so please
make sure it stays idempotent in that case.
"""
assert self.has_full_params
for p in self.params:
@@ -1581,7 +1602,9 @@ class FullyShardedDataParallel(nn.Module):
@torch.no_grad()
def _prep_grads_for_backward(self) -> None:
"""Make sure p.grad is correctly prepared for the backward."""
""" Make sure p.grad is correctly prepared for the backward with
right shape, device, accumulated values, etc.
"""
for p in self.params:
if p.grad is not None:
if p.grad.device != p.data.device:
......
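The idempotency notes in the docstrings above suggest a simple guard pattern. The class below is a simplified, hypothetical stand-in (only the `has_full_params` flag appears in the diff), sketching why repeated calls stay cheap.

```python
# Simplified stand-in for the idempotent rebuild pattern the docstrings
# describe: once full params are gathered, re-invoking is a cheap no-op.
class ShardedParamHolder:
    def __init__(self) -> None:
        self.has_full_params = False

    def rebuild_full_params(self) -> None:
        if self.has_full_params:
            return  # already gathered; safe to call again
        # ... all-gather the shards here ...
        self.has_full_params = True

holder = ShardedParamHolder()
holder.rebuild_full_params()
holder.rebuild_full_params()  # idempotent: the second call does nothing
```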