Unverified commit cd0f0b88 authored by Myle Ott, committed by GitHub

FSDP: supporting gradient accumulation without no_sync context manager to save GPU memory (#752)



* Add test (broken) for gradient accumulation without no_sync context manager

* changelog

* no_sync to grad_acc renaming for tests

* clean up tmp files

* support grad acc without no_sync

* minor

* update changelog

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py

Better assertion from Sam.
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* lint
Co-authored-by: Min Xu <min.xu.public@gmail.com>
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent ba7df621
@@ -6,15 +6,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## NEXT - TBD
### Fixed
- FSDP: fixed the final backward callback in certain activation checkpointed cases. Before this fix,
if a model is activation checkpointed in a certain way, the final backward
callback can fire incorrectly, due to the interaction of autograd and reentrant backward
graphs. With this fix, the final callback is always registered on the outermost
root FSDP instance (i.e. the outermost backward graph), which results in it
firing reliably. This makes FSDP much more robust with respect to different
models and activation checkpoints. [#753]
### Added
- FSDP: support gradient accumulation without the `no_sync` context. This is useful
when training with a smaller number of GPUs while keeping the same overall batch size
as a larger number of GPUs. Compared with the `no_sync` context, this mode consumes less
GPU memory but uses more networking bandwidth. [#752]
## [0.3.9] - 2021-07-26
### Fixed
- FSDP: fixed metadata saving and shard consolidation for MoE cases. When a model has
shared parameters or mixture of expert layers, the handling of state dict
metadata was broken. This release fixes that. [#746]
- OSS: fixed the buckets which would stay in fp16 if `broadcast fp16` was required [#751]
### Added
- FSDP: better performance; use `_allgather_base` and `_reduce_scatter_base` when they are
......
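To make the new `### Added` entry concrete, here is a minimal usage sketch of gradient accumulation with FSDP, both with and without `no_sync`. This is an illustration under assumptions, not code from this commit: `model`, `optimizer`, and `micro_batches` are placeholders, and in the `no_sync` case the last micro-batch's backward runs outside the context so the deferred sync fires.

```python
# Illustrative sketch (not from this commit): gradient accumulation with FSDP.
# `model`, `optimizer`, and `micro_batches` are assumed placeholders.
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def accumulate_and_step(model: FSDP, optimizer, micro_batches, use_no_sync: bool) -> None:
    optimizer.zero_grad()
    if use_no_sync:
        # Accumulate full (unsharded) grads locally; the sync happens in the
        # final backward outside the context. More GPU memory, less communication.
        with model.no_sync():
            for batch in micro_batches[:-1]:
                model(batch).sum().backward()
        model(micro_batches[-1]).sum().backward()
    else:
        # With #752, each backward reduce-scatters and accumulates into the grad
        # shard. Less GPU memory, more networking bandwidth.
        for batch in micro_batches:
            model(batch).sum().backward()
    optimizer.step()
```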
@@ -87,6 +87,7 @@ class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
FullyShardedDataParallel is commonly shortened to FSDP.
.. _`Xu et al.`: https://arxiv.org/abs/2004.13336
.. _DeepSpeed: https://www.deepspeed.ai/
@@ -94,9 +95,9 @@ class FullyShardedDataParallel(nn.Module):
Usage::
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
torch.cuda.set_device(device_id)
sharded_module = FSDP(my_module)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
x = sharded_module(x, y=3, z=torch.Tensor([1]))
loss = x.sum()
@@ -143,7 +144,7 @@ class FullyShardedDataParallel(nn.Module):
Args:
module (nn.Module):
module to be wrapped with FSDP.
process_group (Optional):
process group for sharding
reshard_after_forward (bool, Optional):
@@ -627,7 +628,7 @@ class FullyShardedDataParallel(nn.Module):
return getattr(self.module, name)
def __getstate__(self) -> Dict[str, str]:
"""Serialize the state of the current FSDP instance.
Some properties are not serializable (e.g., process groups, streams), so
we remove them and try to reconstruct them in :func:`__setstate__`.
@@ -726,7 +727,7 @@ class FullyShardedDataParallel(nn.Module):
"""
Returns the local (sharded) state of the module. Parameters are sharded,
so the resulting state_dict can only be loaded after the Module has been
wrapped with FSDP.
"""
with contextlib.ExitStack() as stack:
# Tell any nested FSDP instances not to auto summon full params.
@@ -783,19 +784,23 @@ class FullyShardedDataParallel(nn.Module):
@contextlib.contextmanager
def no_sync(self) -> Generator:
"""
A context manager to disable gradient synchronizations across FSDP
processes. Within this context, gradients will be accumulated on module
variables, which will later be synchronized in the first
forward-backward pass after exiting the context.
.. note:: This likely results in higher memory usage because FSDP will
accumulate the full model gradients (instead of gradient shards)
until the eventual sync.
.. note:: Gradient accumulation can also be done without this context,
avoiding the extra GPU memory overhead but incurring extra
networking overhead.
"""
self._lazy_init()
assert self._is_root, "no_sync on inner FSDP is not supported"
self.assert_state(TrainingState.IDLE)
# This instance may wrap other FSDP instances and we
# need to set all of them to accumulate gradients.
old_flags = []
for m in self.modules():  # includes self
@@ -806,6 +811,7 @@ class FullyShardedDataParallel(nn.Module):
yield
finally:
for m, old_flag in old_flags:
assert m._require_backward_grad_sync is False
m._require_backward_grad_sync = old_flag
@contextlib.contextmanager
@@ -999,7 +1005,7 @@ class FullyShardedDataParallel(nn.Module):
"""
if self._is_root is not None:
return
# No FSDP instance wraps this, else _is_root would be set to False.
self._is_root = True
# If the final backward callback has never been queued, the state should be IDLE.
# If the final backward callback has been queued, the callback should be finished
@@ -1168,7 +1174,7 @@ class FullyShardedDataParallel(nn.Module):
else:
self._use_full_params()
# Prepare p.grad.
self._prep_grads_for_backward()
def _register_hook(t: torch.Tensor) -> torch.Tensor:
@@ -1265,7 +1271,7 @@ class FullyShardedDataParallel(nn.Module):
return
if param.grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")
# If this is a checkpointed module, we check if the following
# counter reaches 0. If not, it is not the final backward call
@@ -1278,7 +1284,8 @@ class FullyShardedDataParallel(nn.Module):
# Free full params. As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
# ``self._require_backward_grad_sync``), since the params will not
# get updated before the next forward. This saves networking
# bandwidth but uses more GPU memory.
self._free_full_params([param])
if self.mixed_precision:
@@ -1344,6 +1351,12 @@ class FullyShardedDataParallel(nn.Module):
param.grad.data = param.grad.data.to(dtype=param.data.dtype)
# Don't let this memory get reused until after the transfer.
orig_param_grad_data.record_stream(torch.cuda.current_stream())
if hasattr(param, "_saved_grad_shard") and param._saved_grad_shard is not None:
assert (
param._saved_grad_shard.shape == param.grad.shape
), f"{param._saved_grad_shard.shape} vs {param.grad.shape}"
param.grad.data += param._saved_grad_shard
delattr(param, "_saved_grad_shard")
# Optionally move gradients to CPU, typically used if one is running
# the optimizer on the CPU.
if self.move_grads_to_cpu:
@@ -1536,10 +1549,23 @@ class FullyShardedDataParallel(nn.Module):
@torch.no_grad()
def _prep_grads_for_backward(self) -> None:
"""Make sure p.grad is correctly prepared for the backward."""
for p in self.params:
if p.grad is not None:
if p.grad.device != p.data.device:
p.grad = None
elif p.grad.size() == p._orig_size:
# This is gradient accumulation with no_sync context.
pass
elif p.grad.size() == p._fp32_shard.shape:
# This is gradient accumulation without no_sync context.
# We save the grad shard and set p.grad to None for this backward pass.
# We will accumulate after this pass's grad is generated and reduced and
# sharded.
p._saved_grad_shard = p.grad.data
p.grad = None
else:
raise AssertionError(f"unexpected grad shape: {p.grad.size()}")
@torch.no_grad()
def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
......
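The `_saved_grad_shard` hunks above implement a stash-then-accumulate pattern split across two places in the file. Below is a consolidated, simplified illustration of that mechanism; it is an assumption-level sketch rather than the actual FSDP implementation, and `shard_numel` stands in for the sharded size check against `p._fp32_shard` in the diff.

```python
# Simplified sketch of the accumulation mechanism introduced in this PR.
import torch


def prep_grad_for_backward(p: torch.nn.Parameter, shard_numel: int) -> None:
    # Before a new backward: if p.grad already holds an accumulated grad *shard*
    # from a previous micro-batch, stash it and clear p.grad so autograd writes
    # a fresh full-size gradient during this pass.
    if p.grad is not None and p.grad.numel() == shard_numel:
        p._saved_grad_shard = p.grad.data  # type: ignore[attr-defined]
        p.grad = None


def accumulate_saved_grad_shard(p: torch.nn.Parameter) -> None:
    # After backward + reduce-scatter: p.grad.data now holds the freshly reduced
    # grad shard, so fold the stashed shard from the previous micro-batch into it.
    if hasattr(p, "_saved_grad_shard"):
        assert p._saved_grad_shard.shape == p.grad.shape
        p.grad.data += p._saved_grad_shard
        delattr(p, "_saved_grad_shard")
```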
@@ -214,8 +214,13 @@ def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_
_, filename = tempfile.mkstemp()
_, filename_rpc = tempfile.mkstemp()
try:
# (lefaudeux) Let mp handle the process joining, join=False and handling context has
# been unstable in the past.
mp.spawn(test_func, args=(world_size, filename, filename_rpc, *args), nprocs=world_size, join=True)
finally:
rmf(filename)
rmf(filename_rpc)
def worker_process(
......
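The test-helper change above wraps `mp.spawn` in `try`/`finally` so the rendezvous temp files are always removed. A standalone sketch of the same cleanup pattern follows; the names are illustrative, and fairscale's `rmf` helper is assumed here to be roughly a remove-ignoring-errors function.

```python
# Illustrative sketch of the cleanup pattern, not the fairscale helper itself.
import os
import tempfile


def rmf(path: str) -> None:
    """Remove a file, ignoring errors (assumed behavior of fairscale's rmf)."""
    try:
        os.remove(path)
    except OSError:
        pass


def run_with_temp_files(fn) -> None:
    _, filename = tempfile.mkstemp()
    _, filename_rpc = tempfile.mkstemp()
    try:
        fn(filename, filename_rpc)
    finally:
        # Clean up even if fn raises or a spawned worker fails.
        rmf(filename)
        rmf(filename_rpc)
```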
@@ -15,6 +15,7 @@ class Parameter(Tensor):
_fp32_shard: Tensor
_fp16_shard: Optional[Tensor]
_shard_bwd_hook: Tuple[Any, Any]
_saved_grad_shard: Tensor
def __new__(cls, data: Tensor, requires_grad: builtins.bool = True): ...
......
@@ -4,7 +4,7 @@ tests/nn/data_parallel/test_fsdp_freezing_weights.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_grad_acc.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
......
@@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import functools
import itertools
import unittest
@@ -17,11 +18,15 @@ from fairscale.utils.testing import DummyProcessGroup, objects_are_equal
from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init
class TestGradAcc(DistributedTest):
def test_transformer(self):
fn = functools.partial(self._test_transformer, config={})
spawn_and_init(fn)
def test_transformer_grad_acc_without_no_sync(self):
fn = functools.partial(self._test_transformer, config={}, use_no_sync_context=False)
spawn_and_init(fn)
def test_transformer_no_flat_params(self):
config = {"flatten_parameters": False}
fn = functools.partial(self._test_transformer, config=config)
@@ -44,22 +49,25 @@ class TestNoSync(DistributedTest):
loss.backward()
@classmethod
def _test_transformer(self, rank, group, config, use_no_sync_context=True):
model = self.get_wrapped_model(group, config=config, add_bn=False)
model.eval()  # turn off dropout for the test
self._test_grad_acc(model, batch_dim=1, use_no_sync_context=use_no_sync_context)
@classmethod
def _test_nested_wrapper(self, rank, group, config):
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
self._test_grad_acc(model, batch_dim=0)
@classmethod
def _test_grad_acc(self, model, batch_dim, use_no_sync_context=True):
# Generate two input batches. We'll test that we get the same grads if
# we train on them sequentially while accumulating grads (with no_sync
# or without no_sync) vs. concatenating the batches and training in one go.
#
# The difference between using no_sync and not using it is a GPU memory vs.
# networking bandwidth tradeoff.
batch1 = model.module.get_input(torch.device("cuda"))
assert isinstance(batch1, tuple)
batch2 = tuple(
@@ -82,7 +90,10 @@ class TestNoSync(DistributedTest):
# Test that we get the same results by accumulating grads.
model.zero_grad()
context = contextlib.suppress()
if use_no_sync_context:
context = model.no_sync()
with context: # accumulate gradients from the first batch
output = model(*batch1)
loss1 = model.module.get_loss(batch1, output)
loss1.backward()
@@ -100,7 +111,7 @@ keys = ["reshard_after_forward", "mixed_precision"]
COMM_CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]
class TestGradAccCommunication(DistributedTest):
@parameterized.expand(COMM_CONFIG_OPTIONS, name_func=rename_test)
def test_communication(self, config):
fn = functools.partial(self._test_communication, config=config)
......
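In the test above, `contextlib.suppress()` with no arguments serves as a no-op context manager so the same `with` block covers both modes. A small sketch of the same pattern using `contextlib.nullcontext()` (Python 3.7+), which states the intent a bit more directly; `maybe_no_sync` is an illustrative helper, not part of the test.

```python
import contextlib


def maybe_no_sync(model, use_no_sync_context: bool):
    """Return model.no_sync() when requested, otherwise a do-nothing context."""
    return model.no_sync() if use_no_sync_context else contextlib.nullcontext()


# Usage (illustrative):
#     with maybe_no_sync(model, use_no_sync_context):
#         output = model(*batch1)
#         model.module.get_loss(batch1, output).backward()
```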