"docs/vscode:/vscode.git/clone" did not exist on "03df281275ad3fcb732a41ab1638c2e89afddb25"
Unverified commit 1204c7cf authored by Myle Ott, committed by GitHub

[perf] FSDP: speed up no_sync and test communication volume (#470)

parent 0b8d0753
@@ -211,6 +211,9 @@ class FullyShardedDataParallel(nn.Module):
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self.training_state = TrainingState.IDLE
# Flag to indicate if the full params are gathered.
self.has_full_params: bool = False
# Register hook after state_dict() to remove the "_fsdp_wrapped_module."
# prefix and before load_state_dict() to add it back.
self._register_state_dict_hook(_post_state_dict_hook)
@@ -511,7 +514,11 @@ class FullyShardedDataParallel(nn.Module):
A context manager to disable gradient synchronizations across DDP
processes. Within this context, gradients will be accumulated on module
variables, which will later be synchronized in the first
forward-backward pass after exiting the context.
.. note:: This may result in higher memory usage because we will
accumulate the full model gradients (instead of gradient shards)
until the eventual sync.
"""
self._lazy_init()
assert self._is_root, "no_sync on inner FSDP is not supported"
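As context for the docstring above: the sketch below shows how ``no_sync`` is typically used for gradient accumulation with FSDP. It is an illustrative example, not part of this commit; ``my_module``, ``loss_fn``, and ``batches`` are hypothetical placeholders, and it assumes ``torch.distributed`` is already initialized. Inside the context each backward keeps accumulating full (unsharded) gradients locally; the first forward-backward pass after exiting the context performs the usual reduce-scatter back into gradient shards, at the cost of the extra memory noted above.

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = FSDP(my_module).cuda()                    # my_module is a placeholder nn.Module
optim = torch.optim.SGD(model.parameters(), lr=0.01)

optim.zero_grad()
with model.no_sync():
    # Accumulate full gradients locally; no reduce-scatter happens here.
    for batch in batches[:-1]:
        loss_fn(model(batch)).backward()
# The first backward after exiting the context syncs (reduce-scatters)
# the accumulated gradients back into per-rank shards.
loss_fn(model(batches[-1])).backward()
optim.step()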
@@ -575,6 +582,7 @@ class FullyShardedDataParallel(nn.Module):
# forward/backward.
self.training_state = TrainingState.SUMMON_FULL_PARAMS
full_tensors = self._rebuild_full_params(full_precision=True)
assert full_tensors is not None
with contextlib.ExitStack() as stack:
if self.flatten_parameters and self.module.is_flattened:
# Update flattened views to point to fully-sized tensors. We
@@ -596,6 +604,7 @@ class FullyShardedDataParallel(nn.Module):
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free:
free_storage_(full_tensor)
self.has_full_params = False
self._use_fp32_param_shard()
self.training_state = TrainingState.IDLE
@@ -833,6 +842,7 @@ class FullyShardedDataParallel(nn.Module):
self._rebuild_full_params()
else:
self._use_full_params()
# Make sure p.grad has the correct size/device (or set it to None).
self._prep_grads_for_backward()
@@ -891,15 +901,22 @@ class FullyShardedDataParallel(nn.Module):
if param.grad.requires_grad:
raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")
if not self._is_root or self._require_backward_grad_sync:
# Free full params. As a special case, we don't free the full params
# on the root instance when in a ``no_sync`` context (as indicated
# by ``self._require_backward_grad_sync``), since we will need the
# params again immediately for the next forward.
self._free_full_params([param])
if self.mixed_precision:
# This is a no-op if reshard_after_forward is True, since we already
# free the param shard when rebuilding the full params in the
# pre_backward_hook.
self._free_fp16_param_shard([param])
# Switch to FP32 shard after backward.
self._use_fp32_param_shard([param])
# (try to) Enqueue a callback at the end of the backward pass to ensure that all
# post-backward work has finished. We only need one callback and all instances
# of FSDP (root and children) make this attempt here to queue to ensure it is queued
@@ -966,9 +983,10 @@ class FullyShardedDataParallel(nn.Module):
def _queue_wait_for_post_backward(self) -> None:
"""Try to queue a `wait_for_post_backward` callback.
Only called on root and only queues one callback. But can be called by
children FSDPs via a closure in case the root instance doesn't own any
params.
"""
assert self._is_root
self.assert_state(TrainingState.BACKWARD)
@@ -978,18 +996,18 @@ class FullyShardedDataParallel(nn.Module):
@torch.no_grad()
def _wait_for_post_backward(self) -> None:
"""Wait for post-backward work to finish. Only called on root instance.
"""
"""Wait for post-backward to finish. Only called on root instance."""
assert self._is_root
self.assert_state(TrainingState.BACKWARD)
if self._require_backward_grad_sync:
# Flush any unreduced buckets in the post_backward stream.
with torch.cuda.stream(self._streams["post_backward"]):
assert self._reducer is not None
self._reducer.flush()
torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
if self.move_grads_to_cpu:
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
torch.cuda.current_stream().synchronize()
# A backward pass is done, update root and nested FSDP's flags.
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
@@ -997,7 +1015,7 @@ class FullyShardedDataParallel(nn.Module):
m.training_state = TrainingState.IDLE
@torch.no_grad()
def _rebuild_full_params(self, full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
"""
Gather all shards of params.
@@ -1008,57 +1026,82 @@ class FullyShardedDataParallel(nn.Module):
(e.g., FP32), possibly in fresh storage.
Returns:
A list of tuples, where the first element is the full-sized param
and the second element is a bool indicating if it's safe for the
caller to free the full-sized param. This will be ``None`` if
``full_precision=False`` and the full params are already gathered.
"""
output_tensors: List[Tuple[torch.Tensor, bool]] = []
def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
if custom_output_tensor is not None:
assert p._is_sharded
p.data = custom_output_tensor
output_tensors.append((p.data, True))
elif not p._is_sharded:
if self.mixed_precision and not full_precision:
p.data = p._fp16_shard
output_tensors.append((p.data, True))
else:
# Here p.data == p._fp32_shard, so it's not safe to free.
output_tensors.append((p.data, False))
else:
p.data = p._full_param_padded
output_tensors.append((p.data, True))
# Trim any padding and reshape to match original size.
p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
# Early exit if we already have full params and don't need full precision.
if self.has_full_params and not full_precision:
for p in self.params:
update_p_data()
return output_tensors
self.has_full_params = True
with torch.cuda.stream(self._streams["all_gather"]):
if self.mixed_precision and not full_precision:
self._cast_fp32_param_shards_to_fp16()
for p in self.params:
if not p._is_sharded: # e.g., when world_size == 1
update_p_data()
else:
# If self.cpu_offload and full_precision, we need to cast
# the FP32 CPU param to CUDA for the all-gather.
p_data = p.data.to(p._full_param_padded.device)
p_size = p._full_param_padded.size()
assert p_size.numel() % self.world_size == 0
if not self.mixed_precision or not full_precision:
if p._full_param_padded.storage().size() != p_size.numel():
# Allocate based on full size from all shards.
alloc_storage_(p._full_param_padded, size=p_size)
output_tensor = p._full_param_padded
else:
# Allocate fresh tensor in full precision.
output_tensor = p_data.new_zeros(p_size)
# Fill output_tensor with (p.data for each shard in self.world_size)
chunks = list(output_tensor.chunk(self.world_size))
dist.all_gather(chunks, p_data, group=self.process_group)
# Set p.data = output_tensor (with padding trimmed)
update_p_data(output_tensor)
if self.mixed_precision and not full_precision:
self._free_fp16_param_shard([p])
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
return output_tensors
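The core of _rebuild_full_params above is an all-gather of each rank's (padded) shard directly into chunk views of one preallocated buffer. Below is a minimal standalone sketch of that trick, assuming torch.distributed is initialized and every rank holds an equally sized 1-D shard; gather_full_param is an illustrative helper, not part of the FSDP API.

import torch
import torch.distributed as dist

def gather_full_param(shard: torch.Tensor, orig_numel: int, world_size: int, group=None) -> torch.Tensor:
    # Preallocate the full (padded) buffer and split it into per-rank views.
    full = shard.new_empty(shard.numel() * world_size)
    chunks = list(full.chunk(world_size))
    # all_gather writes every rank's shard into the corresponding view,
    # filling `full` in place without an extra copy.
    dist.all_gather(chunks, shard, group=group)
    # Trim the padding that was added to make the shards evenly sized.
    return full[:orig_numel]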
@torch.no_grad()
def _use_full_params(self) -> None:
"""Switching p.data pointers to use the full params.
"""
Switch p.data pointers to use the full params.
Note: this is used assuming full param gathering is already done.
Note: this assumes full params are already gathered.
"""
assert self.has_full_params
for p in self.params:
if not p._is_sharded:
if self.mixed_precision:
@@ -1080,6 +1123,7 @@ class FullyShardedDataParallel(nn.Module):
"""Free up storage for full parameters."""
if params is None:
params = self.params
self.has_full_params = False
current_stream = torch.cuda.current_stream()
with torch.cuda.stream(self._streams["all_gather"]):
for p in params:
@@ -1176,7 +1220,6 @@ def free_storage_(data: torch.Tensor) -> None:
# Since we're modifying the Tensor's Storage directly, make sure the Tensor
# is the sole occupant of the Storage.
assert data.storage_offset() == 0
data.storage().resize_(0)
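free_storage_ (and its counterpart alloc_storage_) rely on resizing a tensor's underlying Storage rather than dropping the tensor itself, so parameter attributes such as p._full_param_padded stay alive while their memory is released. Here is a minimal sketch of that behavior, assuming a tensor that solely owns its storage; it is illustrative, not the library code.

import torch

t = torch.empty(1024, 1024)
assert t.storage_offset() == 0        # only safe when `t` is the sole occupant of the storage
t.storage().resize_(0)                # memory released; t.shape is still (1024, 1024)
print(t.storage().size())             # 0
t.storage().resize_(1024 * 1024)      # memory reallocated; contents are uninitialized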
......
tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_no_sync.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
......
@@ -660,86 +660,7 @@ class TestNoGrad(DistributedTest):
with torch.no_grad():
no_grad_output = model(*input)
assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)
class TransformerWithSharedParams(nn.Module):
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import unittest
from unittest.mock import patch
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import DummyProcessGroup, objects_are_equal
from .test_fsdp import DistributedTest, NestedWrappedModule, spawn_and_init
class TestNoSync(DistributedTest):
def test_transformer(self):
fn = functools.partial(self._test_transformer, config={})
spawn_and_init(fn)
def test_transformer_no_flat_params(self):
config = {"flatten_parameters": False}
fn = functools.partial(self._test_transformer, config=config)
spawn_and_init(fn)
def test_nested_wrapper(self):
fn = functools.partial(self._test_nested_wrapper, config={})
spawn_and_init(fn)
def test_no_sync_before_first_forward(self):
group = DummyProcessGroup(rank=0, size=1)
model = self.get_wrapped_model(group, config={}, add_bn=False)
batch = model.module.get_input(torch.device("cuda"))
with model.no_sync():
output = model(*batch)
loss = model.module.get_loss(batch, output)
loss.backward()
output = model(*batch)
loss = model.module.get_loss(batch, output)
loss.backward()
@classmethod
def _test_transformer(self, rank, group, config):
model = self.get_wrapped_model(group, config=config, add_bn=False)
model.eval() # turn off dropout for the test
self._test_no_sync(model, batch_dim=1)
@classmethod
def _test_nested_wrapper(self, rank, group, config):
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
self._test_no_sync(model, batch_dim=0)
@classmethod
def _test_no_sync(self, model, batch_dim):
# Generate two input batches. We'll test that we get the same grads if
# we train on them sequentially while accumulating grads (with no_sync)
# vs. concatenating the batches and training in one go.
batch1 = model.module.get_input(torch.device("cuda"))
assert isinstance(batch1, tuple)
batch2 = tuple(
# This randomly permutes the values in a multi-dim tensor.
x.view(-1)[torch.randperm(x.numel())].view_as(x)
for x in batch1
)
for x, y in zip(batch1, batch2):
assert not torch.all(x == y)
# Concat the batches along batch dimension.
concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2))
# Establish reference behavior on the concat batch.
model.zero_grad()
output = model(*concat_batch)
ref_loss = model.module.get_loss(concat_batch, output)
ref_loss.backward()
ref_grads = [p.grad.detach().clone() for p in model.parameters()]
# Test that we get the same results by accumulating grads.
model.zero_grad()
with model.no_sync(): # accumulate gradients from the first batch
output = model(*batch1)
loss1 = model.module.get_loss(batch1, output)
loss1.backward()
output = model(*batch2)
loss2 = model.module.get_loss(batch2, output)
loss2.backward()
accumulated_loss = loss1 + loss2
accumulated_grads = [p.grad.detach().clone() for p in model.parameters()]
torch.testing.assert_allclose(ref_loss, accumulated_loss)
assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)
class TestNoSyncCommunication(DistributedTest):
def test_communication(self):
config = {"mixed_precision": True}
fn = functools.partial(self._test_communication, config=config)
spawn_and_init(fn)
@classmethod
def _test_communication(self, rank, group, config):
if group.size() == 1:
return
model = self.get_wrapped_model(group, config=config)
batch = model.module.get_input(torch.device("cuda"))
with patch("torch.distributed.all_gather") as mock_all_gather:
with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
with model.no_sync():
output = model(*batch)
loss = model.module.get_loss(batch, output)
loss.backward()
assert mock_all_gather.call_count == 1
assert mock_reduce_scatter.call_count == 0
output = model(*batch)
loss = model.module.get_loss(batch, output)
loss.backward()
assert mock_all_gather.call_count == 1
assert mock_reduce_scatter.call_count == 1
if __name__ == "__main__":
unittest.main()
@@ -48,7 +48,11 @@ class TestMemory(DistributedTest):
del state_dict
mems.append(get_cuda_mem())
# Anything other than equality indicates a memory leak. If mems[4] >
# mems[0], that indicates we're not cleaning up params properly in
# summon_full_params. If mems[4] < mems[0], that indicates there's a
# memory leak in _train_for_several_steps.
assert mems[4] == mems[0], f"memory leak detected, {mems[4]} != {mems[0]}"
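The check above compares the allocator's view of CUDA memory before and after summon_full_params. Below is a hedged sketch of how such a measurement is usually taken; the real get_cuda_mem helper in the test suite may differ, so treat this as an assumption.

import torch

def get_cuda_mem() -> int:
    # Wait for pending work (including async frees on side streams) before reading.
    torch.cuda.synchronize()
    # Bytes currently held by live tensors on the default CUDA device.
    return torch.cuda.memory_allocated()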
class TestPersistence(DistributedTest):
......