Unverified commit 5da5c0eb authored by Min Xu, committed by GitHub

[fix]: Fixes an issue with pre_backward hook registering (#833)



* added the failing test

* fixed the bug

* fine-tune the condition

* typo

* typo

* changelog and added test to test files
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent cabad2f7
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
+- FSDP: Fixed a pre-backward hook bug for certain types of models and FSDP configs. [#833]
 ### Added
 - LayerwiseMemoryTracker[feature][experimental] - This is a new experimental tool to help track, visualize and suggest fix for memory issues occurring during the forward/backward pass of your models. [#808]
......
@@ -942,6 +942,7 @@ class FullyShardedDataParallel(nn.Module):
         for p in self.params:
             if hasattr(p, "_fp32_shard"):
                 del p._fp32_shard  # reset _init_param_attributes
+        self._output_pre_backward_hook_registered: Optional[List] = None

     def _lazy_init(self) -> None:
         """Initialization steps that should happen lazily, typically right
@@ -958,6 +959,7 @@ class FullyShardedDataParallel(nn.Module):
         if self._is_root is None:
             self._set_is_root()
             self._setup_streams()
+            self._setup_output_hook_list()

         if self._is_root:
             # Buffers stay on GPU, and don't get sharded. Since _cast_buffers
@@ -1104,6 +1106,16 @@ class FullyShardedDataParallel(nn.Module):
                 m._streams = self._streams
                 m._reducer = self._reducer

+    def _setup_output_hook_list(self) -> None:
+        """ set up a list to avoid registering pre-backward hooks
+        incorrectly.
+        """
+        assert self._is_root, "This should only be called on the root"
+        self._output_pre_backward_hook_registered = []
+        for n, m in self.named_modules():
+            if n != "" and isinstance(m, FullyShardedDataParallel):
+                m._output_pre_backward_hook_registered = self._output_pre_backward_hook_registered
+
     def _wait_for_previous_optim_step(self) -> None:
         """
         The outer-most :class:`FullyShardedDataParallel` instance (i.e., the root
@@ -1236,9 +1248,15 @@ class FullyShardedDataParallel(nn.Module):
             self.training_state = TrainingState.BACKWARD_PRE
             self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])

+        _registered = 0
+
         def _register_hook(t: torch.Tensor) -> torch.Tensor:
-            if t.requires_grad:
+            nonlocal _registered
+            assert self._output_pre_backward_hook_registered is not None
+            if t.requires_grad and (_registered == 0 or id(t) not in self._output_pre_backward_hook_registered):
                 t.register_hook(_pre_backward_hook)
+                self._output_pre_backward_hook_registered.append(id(t))
+                _registered += 1
             return t

         # Attach hooks to Tensor outputs.
@@ -1526,6 +1544,9 @@ class FullyShardedDataParallel(nn.Module):
                 if m._is_root:
                     # reset this flag for cases like "one forward pass + multiple backward passes"
                     self._post_backward_callback_queued = False
+                    # clear this list for next iteration
+                    assert self._output_pre_backward_hook_registered is not None
+                    self._output_pre_backward_hook_registered.clear()

     @torch.no_grad()
     def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
......
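For readers skimming the diff: torch.Tensor.register_hook attaches a hook that fires when the gradient for that tensor is computed. Before this change, an outer FSDP instance could register its _pre_backward_hook on an output tensor that an inner FSDP instance had already hooked, which is the nested-wrapping situation exercised by the new test below. The fix shares one list across all nested instances and records id(t) for every hooked output, so a tensor is only hooked once per iteration (while the _registered == 0 check still guarantees each wrapper registers at least one hook). The following is a minimal standalone sketch of that id()-based bookkeeping only; it is not part of the commit, and shared_registry and register_once are illustrative names, not fairscale APIs.

    # Minimal sketch (not fairscale code) of deduplicating tensor backward hooks by id().
    import torch

    shared_registry = []  # plays the role of the shared _output_pre_backward_hook_registered list

    def register_once(t, hook):
        # Register a backward hook on t only if no wrapper has hooked this exact tensor yet.
        if t.requires_grad and id(t) not in shared_registry:
            t.register_hook(hook)
            shared_registry.append(id(t))
        return t

    calls = []
    x = torch.ones(2, requires_grad=True)
    y = x * 2  # the same output tensor seen by an "inner" and an "outer" wrapper

    register_once(y, lambda grad: calls.append("pre_backward"))  # inner wrapper: registers the hook
    register_once(y, lambda grad: calls.append("pre_backward"))  # outer wrapper: skipped as a duplicate

    y.sum().backward()
    assert calls == ["pre_backward"]  # the hook fired exactly once
    shared_registry.clear()  # cleared for the next iteration, as the last hunk above does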
+tests/nn/data_parallel/test_fsdp_pre_backward_hook.py
 tests/nn/data_parallel/test_fsdp_overlap.py
 tests/nn/data_parallel/test_fsdp_multiple_forward.py
 tests/nn/data_parallel/test_fsdp_apply.py
......
New file: tests/nn/data_parallel/test_fsdp_pre_backward_hook.py

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with pre-backward hook bug. """

import pytest
import torch
from torch.nn import Linear, Module

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, temp_files_ctx


# A fixture to get tempfiles and ensure they are cleaned up.
@pytest.fixture()
def temp_files():
    # dist_init needs 2 files
    with temp_files_ctx(2) as files:
        yield files


@skip_if_no_cuda
def test_pre_backward_hook(temp_files):
    """Test FSDP with a model that triggers a pre_backward hook bug."""
    result = dist_init(rank=0, world_size=1, filename=temp_files[0], filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.l1 = Linear(4, 4).cuda()
            self.l2 = FSDP(Linear(4, 4).cuda())
            self.l3 = Linear(4, 4).cuda()

        def forward(self, x):
            x = self.l1(x)
            x = self.l2(x)
            inner_result = x
            x = self.l3(x)
            return x, inner_result

        def assert_and_clear_grad(self):
            for p in self.parameters():
                assert p.shape in [(4, 4), (4,), (4 * 4 + 4,)], p.shape
                assert p.grad is not None
                p.grad = None

    model = FSDP(Model(), flatten_parameters=False).cuda()
    in_data = torch.rand(1, 4).cuda()
    for _ in range(3):
        out, _ = model(in_data)
        out.sum().backward()
        model.assert_and_clear_grad()

    teardown()
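The test initializes a single-process group (world_size=1) and is skipped when CUDA is unavailable, so it should be runnable on its own with a standard pytest invocation, for example: pytest tests/nn/data_parallel/test_fsdp_pre_backward_hook.py.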