Unverified commit f4fcee7e authored by Min Xu, committed by GitHub

[Fix][FSDP] Don't remove post-backward hooks, fixing multiple backward passes (#1079)



* tmp

* test again

* test again

* add new test

* clean up

* add test file to the testlist

* more comments

* add changelog
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 8f8f8ef9
@@ -5,7 +5,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.4.7] - TBD
## [0.4.11] - TBD
- cleaned up some old issues and fixed a few bugs in FSDP
- removing SSD offload to simplify the FSDP code
## [0.4.8]/[0.4.9]/[0.4.10]
### Added
@@ -48,6 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
parameter internally.
### Fixed
- fixed some bugs in FSDP related to supporting data2vec EMA modules.
[0.4.6] - 2022-03-08
@@ -586,7 +586,7 @@ class FullyShardedDataParallel(nn.Module):
p must already be sharded by the owning module.
Check the corresponding unit test to see how it is used and tested.
Check the corresponding unit tests to see how it is used and tested.
In particular, the FSDP wrappers that share the parameter are "siblings",
not "parent" and "child", of each other in the nested module structure.
@@ -1700,35 +1700,38 @@ class FullyShardedDataParallel(nn.Module):
3. If it fires only once but too early, or doesn't fire at all, we leave gradients
unsharded (which could leave the gradient dimension too large).
Due to multiple-pass forward, this function can be called on
the same parameter multiple times in a single forward pass. If we register
the hook multiple times, we end up getting called multiple times. We
could try to get a new hook every time and delete the previously
registered one. However, for *unknown reasons* (I have debugged it for
a long time!), in mixed precision mode, we get two different ``grad_acc``
objects below during different calls of this function (in the same
forward pass). If we keep the last one, the hook ends up firing too
early. In full precision mode, we luckily get the *same* ``grad_acc``
object, so deleting and re-registering still ensures that the hook fires
once after all gradients are generated.
Empirically, keeping the first hook registered per forward pass seems to
work best. We do need to remove the hook at the end of the
backward pass; otherwise, the next forward pass will not register
a new hook, which is needed for a new forward pass.
There are several cases here:
1. We can call the same module multiple times in a single outer forward
pass. We register multiple hooks, but autograd should fire the last
one after the total gradient is computed and accumulated. If it does
fire multiple times, we may crash because the gradient is already
sharded and the shapes mismatch.
On the other hand, due to _saved_grad_shard, this case may also work,
just with an extra grad scatter-gather.
2. Case 1 combined with activation checkpointing.
3. The same outer forward can be called multiple times before any backward
is called (within the no_sync context), as a special way of doing gradient
accumulation. (See test_fsdp_fwd_fwd_bwd_bwd.py.)
4. When a param is shared by multiple FSDP wrapper instances, this hook can
be registered multiple times. (See test_fsdp_shared_weights.py.)
It appears that registering the hook every time, letting the hooks fire, and
letting them be removed/freed automatically is the correct thing to do, but
this is purely based on experiments.
"""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
for p in self.params:
if p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
continue
# Register a hook on the first call; empirically, autograd
# fires it at the end for this param, which makes sense.
# Register a hook.
p_tmp = p.expand_as(p) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its AccumulateGrad object.
handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
# Important: we need to save the hook handle; otherwise it appears to be
# deleted/freed/unregistered.
# However, we don't free/unhook at the end of bwd (as we used to do
# in _finalize_parameters below). If we did, we might unregister the wrong hook.
p._shard_bwd_hook = (grad_acc, handle)
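# Illustrative sketch (not fairscale code): how a hook registered on a
# parameter's AccumulateGrad node behaves across repeated backward passes.
# The names `count`, `_hook`, and `p` below are made up for the demo; the
# registration pattern mirrors the code above. As long as the node and
# handle are kept alive (here via the saved tuple), autograd never removes
# the hook by itself and it fires once per backward pass, which is why this
# commit also stops removing it in the WFPB loop further down.
#
#     import torch
#     count = 0
#     def _hook(*unused):
#         global count
#         count += 1
#     p = torch.nn.Parameter(torch.ones(3))
#     grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]  # AccumulateGrad node
#     handle = grad_acc.register_hook(_hook)
#     for _ in range(2):
#         (p * 2).sum().backward()
#     assert count == 2  # fired once per backward pass
#     handle.remove()    # only an explicit remove() (or freeing the node) unregisters it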
@torch.no_grad()
@@ -1756,20 +1759,26 @@ class FullyShardedDataParallel(nn.Module):
# then subsequent hook callbacks will see POST state.
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
self.training_state = TrainingState.BACKWARD_POST
if param.grad is None:
return
if hasattr(param, "_linked_param"):
# This links to a shared param. We should finalize the linked param here.
assert param.shape == (1,), param.shape
# This links to a shared param. We should try to finalize the linked param here.
# This is done by module code to ensure correct gradient computation.
# p._is_shared and p._linked_param are closely related but not the same.
# See fairscale/experimental/nn/mevo.py.
assert param.shape == (1,), param.shape # This param should have this special dim.
# If the _is_shared flag is set, then this shared weight is indeed being
# shared between different FSDP wrappers. Otherwise, they are linked but
# likely in the same FSDP wrapper, which means we shouldn't finalize the
# linked param.
if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
# param._linked_param may or may not have .grad since this callback
# could happen multiple times to support #918. Since we check `if param.grad is None`
# below anyway, this is OK.
param = param._linked_param
assert param.grad is not None, param.shape
if param.grad is None:
return
if param.grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")
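# Illustrative sketch (not fairscale code) of the _linked_param convention
# described in the comments above: a (1,)-shaped placeholder parameter points
# at the real shared weight, and the hook redirects to that weight only when
# it is marked as truly shared across wrappers. All names below are
# hypothetical and only mirror the attributes referenced above.
#
#     import torch
#     import torch.nn as nn
#     shared_weight = nn.Parameter(torch.randn(10, 10))
#     shared_weight._is_shared = True             # shared across FSDP wrappers
#     placeholder = nn.Parameter(torch.zeros(1))  # the (1,)-shaped stand-in
#     placeholder._linked_param = shared_weight
#     param = placeholder
#     if hasattr(param, "_linked_param") and getattr(param._linked_param, "_is_shared", False):
#         param = param._linked_param             # finalize the real shared weight instead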
@@ -1950,10 +1959,6 @@ class FullyShardedDataParallel(nn.Module):
for p in fsdp_module.params:
if not p.requires_grad:
continue
if hasattr(p, "_shard_bwd_hook"):
p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
p._shard_bwd_hook[1].remove()
delattr(p, "_shard_bwd_hook")
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
@@ -1969,11 +1974,13 @@ class FullyShardedDataParallel(nn.Module):
elif hasattr(p, "_saved_grad_shard"):
p_assert(
p.device == p._saved_grad_shard.device,
f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
f"WFPB: incorrect saved_grad_shard device p.device={p.device} "
f"vs p._saved_grad_shard.device={p._saved_grad_shard.device}",
)
p_assert(
p.shape == p._saved_grad_shard.shape,
f"WFPB: incorrect saved_grad_shard shape {p.shape} vs {p._saved_grad_shard.shape}",
f"WFPB: incorrect saved_grad_shard shape p.shape={p.shape} "
f"vs p._saved_grad_shard.shape={p._saved_grad_shard.shape}",
)
p.grad = p._saved_grad_shard
@@ -51,3 +51,4 @@ tests/nn/pipe/test_stream.py
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_top2gating.py
tests/nn/data_parallel/test_fsdp_offload.py
tests/nn/data_parallel/test_fsdp_fwd_fwd_bwd_bwd.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from fairscale.fair_dev.testing.testing import skip_if_single_gpu, temp_files_ctx
from fairscale.nn import enable_wrap, wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel
class FFN(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 10)
self.fc2 = nn.Linear(10, 10)
self.relu = nn.ReLU()
def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
def main(rank, sync_file):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.set_device(rank)
torch.distributed.init_process_group(
backend="nccl",
init_method=f"file://{sync_file}",
world_size=2,
rank=rank,
)
ffn = FFN().cuda().half()
with enable_wrap(wrapper_cls=FullyShardedDataParallel):
model = wrap(
ffn,
process_group=torch.distributed.new_group(),
flatten_parameters=True,
compute_dtype=torch.float16,
)
model = model.train()
# We test this behavior because it might be used by pipelining.
# However, we don't check if the speed (compute/comm overlapping)
# and memory (necessary all-gather & free) are optimal.
losses = []
for _ in range(3):
x = torch.rand((10, 10)).cuda().half()
out = model(x)
loss = out.sum()
losses.append(loss)
# Only the last bwd can be outside of no_sync context.
with model.no_sync():
losses[0].backward()
losses[1].backward()
losses[2].backward()
@skip_if_single_gpu
def test_fwd_fwd_bwd_bwd():
with temp_files_ctx(num=1) as temp_files:
torch.multiprocessing.spawn(
fn=main,
nprocs=2,
args=(temp_files[0],),
join=True,
)
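# For contrast with the fwd-fwd-bwd-bwd pattern tested above, the sketch below
# shows the more common interleaved gradient-accumulation loop, where every
# backward except the last runs inside no_sync right after its own forward.
# It is only an illustration and is not exercised by the test; `model` is
# assumed to be an FSDP-wrapped module like the one built in main().
def _interleaved_accumulation_sketch(model, num_steps=3):
    for step in range(num_steps):
        x = torch.rand((10, 10)).cuda().half()
        if step < num_steps - 1:
            with model.no_sync():  # skip gradient reduction on this step
                model(x).sum().backward()
        else:
            model(x).sum().backward()  # last step reduces the accumulated grads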