Unverified commit 82986ca0, authored by Min Xu, committed by GitHub

[fix] FSDP: multi-pass autograd graph and mixed precision (#513)



* FSDP: multi-pass autograd graph and mixed precision

- added BACKWARD_PRE/POST checking
- better assert_state
- fixed issue of backward hook misfiring

* fix

* cleanup

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Co-authored-by: Myle Ott <myleott@fb.com>

Co-authored-by: Myle Ott <myleott@fb.com>
parent c79bbd01
@@ -7,9 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Added
+- Experimental: Add spectrain support ([#372](https://github.com/facebookresearch/fairscale/issues/372))
 ### Fixed
 - OSS: fix a compatibility problem with lightning wrt optimizer state dict ([#510](https://github.com/facebookresearch/fairscale/issues/510))
+- FSDP: fixed a bug when part of the autograd graph is traversed multiple times in mixed precision mode ([#513](https://github.com/facebookresearch/fairscale/pull/513))
 ## [0.3.1] - 2021-03-09
 ### Added
......
@@ -34,6 +34,13 @@ class TrainingState(Enum):
     Simple enum to indicate what state FSDP is in. Used for asserting
     to make sure APIs are called in the correct state.
 
+    .. note::
+
+        BACKWARD_PRE and BACKWARD_POST states are used to ensure we
+        receive backward hooks in the correct order. They are used to catch
+        unexpected orderings of hook calls (most likely due to changes in our
+        hook registration logic or in the autograd engine).
+
     TODO (Min): It would be nice to capture the stepping state as well.
         Maybe we can use the model.zero_grad() call, but not sure if it
         is called if optim.zero_grad() is used instead.
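To make the intent of the two new backward states concrete, here is a minimal, standalone sketch (not FSDP code) of the kind of ordering check the hooks perform. The expected-state table is inferred from the assert_state calls added elsewhere in this diff and is illustrative only.

```python
from enum import Enum, auto


class TrainingState(Enum):
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


# Illustrative only: the states each backward hook expects to find,
# inferred from the assert_state() calls added in this PR.
EXPECTED = {
    "pre_backward_hook": [TrainingState.IDLE, TrainingState.BACKWARD_PRE],
    "post_backward_hook": [TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST],
}


def check(hook_name: str, current: TrainingState) -> None:
    # Raise instead of assert so the check survives `python -O`,
    # mirroring the assert_state() change further down in this diff.
    if current not in EXPECTED[hook_name]:
        raise ValueError(f"{hook_name}: unexpected state {current}")


check("pre_backward_hook", TrainingState.IDLE)           # first backward hook of a pass
check("post_backward_hook", TrainingState.BACKWARD_PRE)  # first param's post-backward hook
```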
@@ -45,7 +52,8 @@ class TrainingState(Enum):
     IDLE = auto()
     FORWARD = auto()
-    BACKWARD = auto()
+    BACKWARD_PRE = auto()
+    BACKWARD_POST = auto()
     SUMMON_FULL_PARAMS = auto()
@@ -581,7 +589,7 @@ class FullyShardedDataParallel(nn.Module):
         # Set the state so that we assert when trying to go into
         # forward/backward.
         self.training_state = TrainingState.SUMMON_FULL_PARAMS
-        full_tensors = self._rebuild_full_params(full_precision=True)
+        full_tensors = self._rebuild_full_params(force_full_precision=True)
         assert full_tensors is not None
         with contextlib.ExitStack() as stack:
             if self.flatten_parameters and self.module.is_flattened:
@@ -835,7 +843,8 @@ class FullyShardedDataParallel(nn.Module):
            pre_backward_hook_has_run[0] = True
            # Start of a backward pass.
-           self.training_state = TrainingState.BACKWARD
+           self.assert_state([TrainingState.IDLE, TrainingState.BACKWARD_PRE])
+           self.training_state = TrainingState.BACKWARD_PRE
            # All-gather full parameters.
            if self.reshard_after_forward:
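The pre-backward hook above fires when autograd first reaches this module's output, which is why it may legally observe either IDLE (the first hook of the pass) or BACKWARD_PRE (a later output of the same module). A standalone sketch of the mechanism using a plain tensor hook on a forward output, not the FSDP implementation:

```python
import torch

state = {"mode": "IDLE"}


def pre_backward_hook(grad):
    # Runs when backward reaches the forward output, i.e. at the start of the
    # backward pass for this "module"; tolerate being called again if the
    # module produced several outputs.
    assert state["mode"] in ("IDLE", "BACKWARD_PRE")
    state["mode"] = "BACKWARD_PRE"
    return grad


x = torch.randn(2, 3, requires_grad=True)
y = x * 2                       # stand-in for a module's forward output
y.register_hook(pre_backward_hook)
y.sum().backward()
print(state["mode"])            # BACKWARD_PRE
```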
@@ -857,7 +866,40 @@ class FullyShardedDataParallel(nn.Module):
         return outputs
 
     def _register_post_backward_hooks(self) -> None:
-        """Register backward hooks to reshard params and reduce-scatter grads."""
+        """
+        Register backward hooks to reshard params and reduce-scatter grads.
+
+        This is called during the forward pass. The goal is to attach a hook
+        on each parameter's gradient-generating function (``grad_acc``
+        below) so that the hook is called *after* all gradients for that
+        param have been computed.
+
+        Goals:
+
+        1. We want the hook to fire once and only once *after* all gradients
+           are accumulated for a param.
+        2. If it fires more than once, we end up sharding the grad multiple
+           times, which is incorrect (the gradient dimension can become too small).
+        3. If it fires too early, or does not fire at all, we leave the gradients
+           unsharded (the gradient dimension can become too large).
+
+        Because of multi-pass forward, this function can be called on the same
+        parameter multiple times in a single forward pass. If we registered the
+        hook multiple times, it would also fire multiple times. We could try to
+        register a new hook on every call and delete the previously registered
+        one. However, for an *unknown reason* (I have debugged it for a long
+        time!), in mixed precision mode we get two different ``grad_acc``
+        objects below during different calls of this function (within the same
+        forward pass). If we keep only the last one, the hook ends up firing
+        too early. In full precision mode we luckily get the *same* ``grad_acc``
+        object, so deleting and re-registering still ensures that the hook fires
+        once after all gradients are generated.
+
+        Empirically, keeping the first hook registered per forward pass seems
+        to work best. We do need to remove the hook at the end of the backward
+        pass; otherwise, the next forward pass will not register a new hook,
+        which is needed for a new forward pass.
+        """
         if not torch.is_grad_enabled():
             return  # don't register grad hooks if grad isn't enabled
         if self._is_root:
@@ -868,9 +910,11 @@ class FullyShardedDataParallel(nn.Module):
         for p in self.params:
             if p.requires_grad:
                 if hasattr(p, "_shard_bwd_hook"):
-                    p._shard_bwd_hook[1].remove()  # remove existing handle
-                p_tmp = p.expand_as(p)
-                grad_acc = p_tmp.grad_fn.next_functions[0][0]
+                    continue
+                # Register a hook only on the first call; empirically, autograd
+                # fires it at the end for this param, which is what we want.
+                p_tmp = p.expand_as(p)  # Get a grad_fn on p_tmp.
+                grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Get its AccumulateGrad object.
                 handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
                 p._shard_bwd_hook = (grad_acc, handle)
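A standalone sketch (not FSDP code) of the mechanism the docstring describes: the expand_as trick exposes the parameter's AccumulateGrad node, and a hook registered on that node fires exactly once per backward pass, after all uses of the parameter have contributed their gradients, even when the module is run twice in one forward pass.

```python
import torch

lin = torch.nn.Linear(4, 4)
p = lin.weight
fired = []

# expand_as(p) creates a non-leaf view whose grad_fn points back at the
# leaf parameter's AccumulateGrad node via next_functions.
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]


def post_backward_hook(*unused):
    fired.append(True)


handle = grad_acc.register_hook(post_backward_hook)

x = torch.randn(2, 4)
out = lin(x) + lin(x)   # the module is used twice, as in the new test below
out.sum().backward()

print(len(fired))       # 1 -- fires once, after both uses accumulated grads
handle.remove()         # remove it so the next forward pass can re-register
```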
@@ -895,7 +939,10 @@ class FullyShardedDataParallel(nn.Module):
         alignment is created by :func:`_shard_parameters_`, which ensures that
         the local optimizer only sees the relevant parameter shard.
         """
-        self.assert_state(TrainingState.BACKWARD)
+        # The first hook callback will see the PRE state. If we have multiple
+        # params, subsequent hook callbacks will see the POST state.
+        self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
+        self.training_state = TrainingState.BACKWARD_POST
         if param.grad is None:
             return
         if param.grad.requires_grad:
@@ -966,7 +1013,7 @@ class FullyShardedDataParallel(nn.Module):
         """Hook to call on each param after the reduce-scatter."""
         assert torch.cuda.current_stream() == self._streams["post_backward"]
         assert param.grad is not None
-        self.assert_state(TrainingState.BACKWARD)
+        self.assert_state(TrainingState.BACKWARD_POST)
         param.grad.data = reduced_grad
         # Cast grad to param's dtype (typically FP32). Note: we do this
         # before the move_grads_to_cpu step so that this entire hook remains
@@ -992,7 +1039,7 @@ class FullyShardedDataParallel(nn.Module):
         params.
         """
         assert self._is_root
-        self.assert_state(TrainingState.BACKWARD)
+        self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
         if not self._post_backward_callback_queued:
             self._post_backward_callback_queued = True
             Variable._execution_engine.queue_callback(self._wait_for_post_backward)
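Variable._execution_engine.queue_callback schedules a function to run once the autograd engine finishes the current backward pass; that is how the root instance defers _wait_for_post_backward. A minimal standalone illustration of that call, queued from inside a backward hook:

```python
import torch
from torch.autograd import Variable

events = []


def final_callback():
    # Runs after the autograd engine has finished the current backward pass.
    events.append("backward finished")


def hook(grad):
    # Queue the callback from inside the backward pass, analogous to
    # _queue_wait_for_post_backward above.
    Variable._execution_engine.queue_callback(final_callback)
    return grad


x = torch.randn(3, requires_grad=True)
y = (x * x).sum()
y.register_hook(hook)
y.backward()
print(events)  # ['backward finished']
```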
@@ -1001,7 +1048,20 @@ class FullyShardedDataParallel(nn.Module):
     def _wait_for_post_backward(self) -> None:
         """Wait for post-backward to finish. Only called on root instance."""
         assert self._is_root
-        self.assert_state(TrainingState.BACKWARD)
+        if self._has_params:
+            self.assert_state(TrainingState.BACKWARD_POST)
+        else:
+            self.assert_state(TrainingState.BACKWARD_PRE)
+
+        def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None:
+            """Helper used below on all fsdp modules."""
+            for p in fsdp_module.params:
+                if p.requires_grad:
+                    if hasattr(p, "_shard_bwd_hook"):
+                        assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
+                        p._shard_bwd_hook[1].remove()
+                        delattr(p, "_shard_bwd_hook")
+
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self._streams["post_backward"]):
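The helper above removes every registered hook handle at the end of backward so that the next forward pass starts clean. A compact standalone sketch of that register-once / remove-after-backward lifecycle over two training iterations (helper names like register_once are hypothetical, not FSDP APIs):

```python
import torch

lin = torch.nn.Linear(4, 4)


def register_once(p, hook):
    # Keep only the first registration within a forward pass.
    if hasattr(p, "_shard_bwd_hook"):
        return
    grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]
    p._shard_bwd_hook = (grad_acc, grad_acc.register_hook(hook))


def remove_after_backward(p):
    # Mirrors _remove_shard_bwd_hook: drop the handle so the next forward
    # pass registers a fresh hook.
    if hasattr(p, "_shard_bwd_hook"):
        p._shard_bwd_hook[1].remove()
        delattr(p, "_shard_bwd_hook")


for _ in range(2):                                   # two training iterations
    x = torch.randn(2, 4)
    register_once(lin.weight, lambda *unused: None)  # done during "forward"
    (lin(x) + lin(x)).sum().backward()               # multi-pass forward, one backward
    remove_after_backward(lin.weight)                # end of the backward pass
```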
@@ -1014,35 +1074,47 @@ class FullyShardedDataParallel(nn.Module):
         # A backward pass is done, update root and nested FSDP's flags.
         for m in self.modules():  # includes self
             if isinstance(m, FullyShardedDataParallel):
-                m.assert_state(TrainingState.BACKWARD)
+                _remove_shard_bwd_hook(m)
+                if m._has_params:
+                    m.assert_state(TrainingState.BACKWARD_POST)
+                else:
+                    m.assert_state(TrainingState.BACKWARD_PRE)
                 m.training_state = TrainingState.IDLE
 
     @torch.no_grad()
-    def _rebuild_full_params(self, full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
+    def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
         """
         Gather all shards of params.
 
         Args:
-            full_precision (bool, Optional): by default params will be gathered
-                in ``compute_dtype`` (e.g., FP16), unless *full_precision* is
+            force_full_precision (bool, Optional): by default params will be gathered
+                in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
                 ``True``, in which case they will be gathered in full precision
-                (e.g., FP32), possibly in fresh storage.
+                (e.g., FP32), possibly in fresh storage. The parameter that is being
+                rebuilt will end up in full precision as well.
 
         Returns:
             A list of tuples, where the first element is the full-sized param
             and the second element is a bool indicating if it's safe for the
             caller to free the full-sized param. This will be ``None`` if
-            ``full_precision=False`` and the full params are already gathered.
+            ``force_full_precision=False`` and the full params are already gathered.
         """
         output_tensors: List[Tuple[torch.Tensor, bool]] = []
 
         def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
+            """
+            Helper function to update p.data pointer.
+
+            Args:
+                custom_output_tensor (torch.Tensor, Optional): if not None, this
+                    tensor contains the data we just gathered.
+            """
             if custom_output_tensor is not None:
                 assert p._is_sharded
                 p.data = custom_output_tensor
                 output_tensors.append((p.data, True))
             elif not p._is_sharded:
-                if self.mixed_precision and not full_precision:
+                if self.mixed_precision and not force_full_precision:
                     p.data = p._fp16_shard
                     output_tensors.append((p.data, True))
                 else:
@@ -1055,7 +1127,7 @@ class FullyShardedDataParallel(nn.Module):
                     p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
 
         # Early exit if we already have full params and don't need full precision.
-        if self.has_full_params and not full_precision:
+        if self.has_full_params and not force_full_precision:
             for p in self.params:
                 update_p_data()
             return output_tensors
@@ -1063,27 +1135,28 @@ class FullyShardedDataParallel(nn.Module):
         self.has_full_params = True
 
         with torch.cuda.stream(self._streams["all_gather"]):
-            if self.mixed_precision and not full_precision:
+            if self.mixed_precision and not force_full_precision:
                 self._cast_fp32_param_shards_to_fp16()
 
             for p in self.params:
                 if not p._is_sharded:  # e.g., when world_size == 1
                     update_p_data()
                 else:
-                    # If self.cpu_offload and full_precision, we need to cast
+                    # If self.cpu_offload and force_full_precision, we need to cast
                     # the FP32 CPU param to CUDA for the all-gather.
                     p_data = p.data.to(p._full_param_padded.device)
 
                     p_size = p._full_param_padded.size()
                     assert p_size.numel() % self.world_size == 0
-                    if not self.mixed_precision or not full_precision:
+                    if self.mixed_precision and force_full_precision:
+                        # Allocate a fresh tensor in full precision, since we are in
+                        # mixed precision mode and a full precision rebuild was requested.
+                        output_tensor = p_data.new_zeros(p_size)
+                    else:
                         if p._full_param_padded.storage().size() != p_size.numel():
                             # Allocate based on full size from all shards.
                             alloc_storage_(p._full_param_padded, size=p_size)
                         output_tensor = p._full_param_padded
-                    else:
-                        # Allocate fresh tensor in full precision.
-                        output_tensor = p_data.new_zeros(p_size)
 
                     # Fill output_tensor with (p.data for each shard in self.world_size)
                     chunks = list(output_tensor.chunk(self.world_size))
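For reference, a single-process sketch (not FSDP code) of the padding and chunking arithmetic used here: the flat parameter is padded to a multiple of world_size, the output buffer is chunked into per-rank views, each rank's shard is copied into its chunk (in the real code this copy is a dist.all_gather), and the padding is trimmed afterwards.

```python
import torch

world_size = 2
orig = torch.arange(7, dtype=torch.float32)                  # flat parameter of size 7
padded_numel = -(-orig.numel() // world_size) * world_size   # round up to 8

# Per-rank shards, as produced by sharding the padded parameter.
shards = list(
    torch.cat([orig, orig.new_zeros(padded_numel - orig.numel())]).chunk(world_size)
)

output_tensor = orig.new_zeros(padded_numel)    # plays the role of p._full_param_padded
chunks = list(output_tensor.chunk(world_size))  # views into output_tensor
for rank, shard in enumerate(shards):           # dist.all_gather(chunks, shard) in the real code
    chunks[rank].copy_(shard)

full_param = output_tensor[: orig.numel()]      # trim padding, like p.data[: p._orig_size.numel()]
assert torch.equal(full_param, orig)
```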
@@ -1092,7 +1165,7 @@ class FullyShardedDataParallel(nn.Module):
                     # Set p.data = output_tensor (with padding trimmed)
                     update_p_data(output_tensor)
 
-                    if self.mixed_precision and not full_precision:
+                    if self.mixed_precision and not force_full_precision:
                         self._free_fp16_param_shard([p])
 
         torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
         return output_tensors
@@ -1180,11 +1253,19 @@ class FullyShardedDataParallel(nn.Module):
             p._fp16_shard.record_stream(current_stream)
             free_storage_(p._fp16_shard)
 
-    def assert_state(self, state: TrainingState) -> None:
+    def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
         """Assert we are in the given state."""
-        assert (
-            self.training_state == state
-        ), f"expected to be in state {state} but current state is {self.training_state}"
+        # Since asserts can be turned off and this error checking
+        # is really important, we use explicit error checking here
+        # and raise a ValueError if needed.
+        if isinstance(state, TrainingState):
+            state = [state]
+        if self.training_state not in state:
+            msg = f"expected to be in states {state} but current state is {self.training_state}"
+            # In case we are failing in the context of an autograd hook, an
+            # assertion may not produce a useful message, so print it as well.
+            print(msg)
+            raise ValueError(msg)
 
     @torch.no_grad()
......
@@ -3,6 +3,7 @@ tests/nn/data_parallel/test_fsdp_grad_scaler.py
 tests/nn/data_parallel/test_fsdp_no_sync.py
 tests/nn/data_parallel/test_fsdp_summon_full_params.py
 tests/nn/data_parallel/test_fsdp_input.py
+tests/nn/data_parallel/test_fsdp_multiple_forward.py
 tests/nn/data_parallel/test_sharded_ddp_features.py
 tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
 tests/nn/pipe/skip/test_gpipe.py
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with different multiple forward of the same module. """
import tempfile
import pytest
import torch
import torch.multiprocessing as mp
from torch.nn import Linear, Module
from torch.optim import SGD
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
result = dist_init(rank, world_size, tempfile_name, unused)
assert result, "Dist init failed"
assert isinstance(fsdp_config, dict), str(fsdp_config)
class Model(Module):
def __init__(self):
super().__init__()
self.inner = FSDP(Linear(4, 4), **fsdp_config)
self.outer = Linear(4, 5)
def forward(self, x):
# Forward twice.
i = self.inner(x)
j = self.inner(x)
return self.outer(i + j)
model = FSDP(Model(), **fsdp_config).cuda()
optim = SGD(model.parameters(), lr=0.1)
for _ in range(3):
in_data = torch.rand(64, 4).cuda()
in_data.requires_grad = True
out = model(in_data)
out.sum().backward()
optim.step()
optim.zero_grad()
model.assert_state(TrainingState.IDLE)
teardown()
# We use strings for precision and flatten instead of bool to
# make the pytest output more readable.
@skip_if_single_gpu
@pytest.mark.parametrize("precision", ["full", "mixed"])
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
def test1(precision, flatten):
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter")
temp_file_name = tempfile.mkstemp()[1]
unused = tempfile.mkstemp()[1]
fsdp_config = {}
fsdp_config["mixed_precision"] = precision == "mixed"
fsdp_config["flatten_parameters"] = flatten == "flatten"
# Some bugs only show up when we are in world_size > 1 due to sharding changing
# the tensor dimensions.
world_size = 2
mp.spawn(
_test_func, args=(world_size, fsdp_config, temp_file_name, unused), nprocs=world_size, join=True,
)