Unverified commit 1204c7cf authored by Myle Ott, committed by GitHub

[perf] FSDP: speed up no_sync and test communication volume (#470)

parent 0b8d0753
@@ -211,6 +211,9 @@ class FullyShardedDataParallel(nn.Module):
         # Enum to indicate if we're in the forward/backward pass, idle, etc.
         self.training_state = TrainingState.IDLE

+        # Flag to indicate if the full params are gathered.
+        self.has_full_params: bool = False
+
         # Register hook after state_dict() to remove the "_fsdp_wrapped_module."
         # prefix and before load_state_dict() to add it back.
         self._register_state_dict_hook(_post_state_dict_hook)
@@ -511,7 +514,11 @@ class FullyShardedDataParallel(nn.Module):
         A context manager to disable gradient synchronizations across DDP
         processes. Within this context, gradients will be accumulated on module
         variables, which will later be synchronized in the first
-        forward-backward pass exiting the context.
+        forward-backward pass after exiting the context.
+
+        .. note:: This may result in higher memory usage because we will
+            accumulate the full model gradients (instead of gradient shards)
+            until the eventual sync.
         """
         self._lazy_init()
         assert self._is_root, "no_sync on inner FSDP is not supported"
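For orientation, here is a minimal usage sketch of the no_sync() context manager documented above; it mirrors the pattern exercised by the new test file added in this commit. The names model, batches, loss_fn, and optimizer are hypothetical placeholders, and model is assumed to be the root FullyShardedDataParallel instance.

# Hypothetical sketch (placeholders: model, batches, loss_fn, optimizer).
# Gradients accumulate as full (unsharded) tensors inside no_sync(); the final
# backward outside the context performs the gradient synchronization.
with model.no_sync():
    for micro_batch in batches[:-1]:
        loss_fn(model(micro_batch)).backward()   # local accumulation, no reduce-scatter
loss_fn(model(batches[-1])).backward()           # grads are synchronized here
optimizer.step()
optimizer.zero_grad()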
@@ -575,6 +582,7 @@ class FullyShardedDataParallel(nn.Module):
             # forward/backward.
             self.training_state = TrainingState.SUMMON_FULL_PARAMS
             full_tensors = self._rebuild_full_params(full_precision=True)
+            assert full_tensors is not None
             with contextlib.ExitStack() as stack:
                 if self.flatten_parameters and self.module.is_flattened:
                     # Update flattened views to point to fully-sized tensors. We
@@ -596,6 +604,7 @@ class FullyShardedDataParallel(nn.Module):
                             p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
                         if safe_to_free:
                             free_storage_(full_tensor)
+                    self.has_full_params = False
                     self._use_fp32_param_shard()
                     self.training_state = TrainingState.IDLE
@@ -833,6 +842,7 @@ class FullyShardedDataParallel(nn.Module):
                 self._rebuild_full_params()
             else:
                 self._use_full_params()
             # Make sure p.grad has the correct size/device (or set it to None).
             self._prep_grads_for_backward()
@@ -891,15 +901,22 @@ class FullyShardedDataParallel(nn.Module):
         if param.grad.requires_grad:
             raise RuntimeError("FullyShardedDataParallel only works with gradients that don't require grad")

-        # Free full params and switch to FP32 shard after backward.
-        self._free_full_params([param])
-        self._use_fp32_param_shard([param])
+        if not self._is_root or self._require_backward_grad_sync:
+            # Free full params. As a special case, we don't free the full params
+            # on the root instance when in a ``no_sync`` context (as indicated
+            # by ``self._require_backward_grad_sync``), since we will need the
+            # params again immediately for the next forward.
+            self._free_full_params([param])
+
         if self.mixed_precision:
             # This is a no-op if reshard_after_forward is True, since we already
             # free the param shard when rebuilding the full params in the
             # pre_backward_hook.
             self._free_fp16_param_shard([param])
+
+        # Switch to FP32 shard after backward.
+        self._use_fp32_param_shard([param])
+
         # (try to) Enqueue a callback at the end of the backward pass to ensure that all
         # post-backward work has finished. We only need one callback and all instances
         # of FSDP (root and children) make this attempt here to queue to ensure it is queued
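The _require_backward_grad_sync flag consulted above is the switch that no_sync() flips. The body of the context manager is not part of this diff, so the following is only an assumed sketch of how that toggling plausibly works, not the verbatim implementation.

# Assumed sketch (not shown in this diff): no_sync() clears the sync flag on
# every nested FSDP instance for the duration of the context, then restores it.
# Requires `import contextlib`; lives on FullyShardedDataParallel.
@contextlib.contextmanager
def no_sync(self):
    self._lazy_init()
    assert self._is_root, "no_sync on inner FSDP is not supported"
    old_flags = []
    for m in self.modules():  # includes self
        if isinstance(m, FullyShardedDataParallel):
            old_flags.append((m, m._require_backward_grad_sync))
            m._require_backward_grad_sync = False
    try:
        yield
    finally:
        for m, old_flag in old_flags:
            m._require_backward_grad_sync = old_flag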
@@ -966,9 +983,10 @@ class FullyShardedDataParallel(nn.Module):
     def _queue_wait_for_post_backward(self) -> None:
         """Try to queue a `wait_for_post_backward` callback.

-        Only called on root and only queue one callback.
-        But can be called by children FSDPs via a closure in case the
-        root instance doesn't own any params.
+        Only called on root and only queue one callback. But can be called by
+        children FSDPs via a closure in case the root instance doesn't own any
+        params.
         """
         assert self._is_root
         self.assert_state(TrainingState.BACKWARD)
@@ -978,18 +996,18 @@ class FullyShardedDataParallel(nn.Module):
     @torch.no_grad()
     def _wait_for_post_backward(self) -> None:
-        """Wait for post-backward work to finish. Only called on root instance.
-        """
+        """Wait for post-backward to finish. Only called on root instance."""
         assert self._is_root
         self.assert_state(TrainingState.BACKWARD)
-        # Flush any unreduced buckets in the post_backward stream.
-        with torch.cuda.stream(self._streams["post_backward"]):
-            assert self._reducer is not None
-            self._reducer.flush()
-        torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
-        if self.move_grads_to_cpu:
-            # Wait for the non-blocking GPU -> CPU grad transfers to finish.
-            torch.cuda.current_stream().synchronize()
+        if self._require_backward_grad_sync:
+            # Flush any unreduced buckets in the post_backward stream.
+            with torch.cuda.stream(self._streams["post_backward"]):
+                assert self._reducer is not None
+                self._reducer.flush()
+            torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
+            if self.move_grads_to_cpu:
+                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
+                torch.cuda.current_stream().synchronize()
         # A backward pass is done, update root and nested FSDP's flags.
         for m in self.modules():  # includes self
             if isinstance(m, FullyShardedDataParallel):
@@ -997,7 +1015,7 @@ class FullyShardedDataParallel(nn.Module):
                 m.training_state = TrainingState.IDLE

     @torch.no_grad()
-    def _rebuild_full_params(self, full_precision: bool = False) -> List[Tuple[torch.Tensor, bool]]:
+    def _rebuild_full_params(self, full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
         """
         Gather all shards of params.
@@ -1008,57 +1026,82 @@ class FullyShardedDataParallel(nn.Module):
                 (e.g., FP32), possibly in fresh storage.

         Returns:
-            a list of tuples, where the first element is the full-sized param
+            A list of tuples, where the first element is the full-sized param
             and the second element is a bool indicating if it's safe for the
-            caller to free the full-sized param
+            caller to free the full-sized param. This will be ``None`` if
+            ``full_precision=False`` and the full params are already gathered.
         """
         output_tensors: List[Tuple[torch.Tensor, bool]] = []
+
+        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
+            if custom_output_tensor is not None:
+                assert p._is_sharded
+                p.data = custom_output_tensor
+                output_tensors.append((p.data, True))
+            elif not p._is_sharded:
+                if self.mixed_precision and not full_precision:
+                    p.data = p._fp16_shard
+                    output_tensors.append((p.data, True))
+                else:
+                    # Here p.data == p._fp32_shard, so it's not safe to free.
+                    output_tensors.append((p.data, False))
+            else:
+                p.data = p._full_param_padded
+                output_tensors.append((p.data, True))
+            # Trim any padding and reshape to match original size.
+            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
+
+        # Early exit if we already have full params and don't need full precision.
+        if self.has_full_params and not full_precision:
+            for p in self.params:
+                update_p_data()
+            return output_tensors
+
+        self.has_full_params = True
         with torch.cuda.stream(self._streams["all_gather"]):
             if self.mixed_precision and not full_precision:
                 self._cast_fp32_param_shards_to_fp16()
             for p in self.params:
                 if not p._is_sharded:  # e.g., when world_size == 1
-                    if self.mixed_precision and not full_precision:
-                        p.data = p._fp16_shard
-                        output_tensors.append((p.data, True))
-                    else:
-                        output_tensors.append((p.data, False))
-                    continue
-                # If self.cpu_offload and full_precision, we need to cast the
-                # FP32 CPU param to CUDA for the all-gather.
-                p_data = p.data.to(p._full_param_padded.device)
-                p_size = p._full_param_padded.size()
-                assert p_size.numel() % self.world_size == 0
-                if not self.mixed_precision or not full_precision:
-                    if p._full_param_padded.storage().size() != p_size.numel():
-                        # Allocate based on full size from all shards.
-                        alloc_storage_(p._full_param_padded, size=p_size)
-                    output_tensor = p._full_param_padded
-                else:
-                    # Allocate fresh tensor in full precision.
-                    output_tensor = p_data.new_zeros(p_size)
-                output_tensors.append((output_tensor, True))
-                # Fill output_tensor with (p.data for each shard in self.world_size)
-                chunks = list(output_tensor.chunk(self.world_size))
-                dist.all_gather(chunks, p_data, group=self.process_group)
-                p.data = output_tensor[: p._orig_size.numel()].view(p._orig_size)
-                if self.mixed_precision and not full_precision:
-                    self._free_fp16_param_shard([p])
+                    update_p_data()
+                else:
+                    # If self.cpu_offload and full_precision, we need to cast
+                    # the FP32 CPU param to CUDA for the all-gather.
+                    p_data = p.data.to(p._full_param_padded.device)
+                    p_size = p._full_param_padded.size()
+                    assert p_size.numel() % self.world_size == 0
+                    if not self.mixed_precision or not full_precision:
+                        if p._full_param_padded.storage().size() != p_size.numel():
+                            # Allocate based on full size from all shards.
+                            alloc_storage_(p._full_param_padded, size=p_size)
+                        output_tensor = p._full_param_padded
+                    else:
+                        # Allocate fresh tensor in full precision.
+                        output_tensor = p_data.new_zeros(p_size)
+                    # Fill output_tensor with (p.data for each shard in self.world_size)
+                    chunks = list(output_tensor.chunk(self.world_size))
+                    dist.all_gather(chunks, p_data, group=self.process_group)
+                    # Set p.data = output_tensor (with padding trimmed)
+                    update_p_data(output_tensor)
+                    if self.mixed_precision and not full_precision:
+                        self._free_fp16_param_shard([p])
         torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
         return output_tensors

     @torch.no_grad()
     def _use_full_params(self) -> None:
-        """Switching p.data pointers to use the full params.
+        """
+        Switch p.data pointers to use the full params.

-        Note: this is used assuming full param gathering is already done.
+        Note: this assumes full params are already gathered.
         """
+        assert self.has_full_params
         for p in self.params:
             if not p._is_sharded:
                 if self.mixed_precision:
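A brief aside on the trim-and-reshape step at the end of update_p_data above: the gathered buffer is padded so it divides evenly across ranks, and the padding is cut off before viewing it in the parameter's original shape. The following is a self-contained toy illustration, not code from this commit; the names are illustrative only.

import torch

# Toy version of "p.data = p.data[: p._orig_size.numel()].view(p._orig_size)":
# a padded flat buffer (as produced by all_gather over equal-sized shards) is
# trimmed back to the original element count and reshaped.
orig_size = torch.Size([3, 5])                                       # 15 elements
world_size = 4
shard_numel = (orig_size.numel() + world_size - 1) // world_size     # 4 per rank
full_param_padded = torch.arange(shard_numel * world_size, dtype=torch.float32)  # 16 elements
full_param = full_param_padded[: orig_size.numel()].view(orig_size)
assert full_param.shape == orig_size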
@@ -1080,6 +1123,7 @@ class FullyShardedDataParallel(nn.Module):
         """Free up storage for full parameters."""
         if params is None:
             params = self.params
+        self.has_full_params = False
         current_stream = torch.cuda.current_stream()
         with torch.cuda.stream(self._streams["all_gather"]):
             for p in params:
@@ -1176,7 +1220,6 @@ def free_storage_(data: torch.Tensor) -> None:
         # Since we're modifying the Tensor's Storage directly, make sure the Tensor
         # is the sole occupant of the Storage.
         assert data.storage_offset() == 0
-        assert data.storage().size() == data.numel()
         data.storage().resize_(0)
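The free_storage_ helper changed above relies on resizing a tensor's underlying Storage in place. Here is a standalone toy sketch of that trick (illustrative only, not code from this commit):

import torch

# Shrinking a tensor's underlying Storage to zero releases its memory while the
# Python tensor object (and its metadata) stays alive, so code like the above
# can later re-point p.data or re-allocate the storage in place.
t = torch.zeros(1024)
assert t.storage().size() == 1024
t.storage().resize_(0)       # memory released; do not read t's data now
assert t.storage().size() == 0
t.storage().resize_(1024)    # re-allocated; contents are uninitialized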
......
 tests/nn/data_parallel/test_fsdp_uneven.py
 tests/nn/data_parallel/test_fsdp_grad_scaler.py
+tests/nn/data_parallel/test_fsdp_no_sync.py
 tests/nn/data_parallel/test_fsdp_summon_full_params.py
 tests/nn/data_parallel/test_sharded_ddp_features.py
 tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
......
@@ -660,86 +660,7 @@ class TestNoGrad(DistributedTest):
         with torch.no_grad():
             no_grad_output = model(*input)
-        assert objects_are_equal(ref_output, no_grad_output), "no_grad_output did not match ref_output"
+        assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)
[The roughly 80 removed lines here are the entire TestNoSync class, deleted from test_fsdp.py; this commit moves it verbatim into the new file tests/nn/data_parallel/test_fsdp_no_sync.py, reproduced in full below.]
 class TransformerWithSharedParams(nn.Module):
......
tests/nn/data_parallel/test_fsdp_no_sync.py (new file)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import unittest
from unittest.mock import patch

import torch

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import DummyProcessGroup, objects_are_equal

from .test_fsdp import DistributedTest, NestedWrappedModule, spawn_and_init


class TestNoSync(DistributedTest):
    def test_transformer(self):
        fn = functools.partial(self._test_transformer, config={})
        spawn_and_init(fn)

    def test_transformer_no_flat_params(self):
        config = {"flatten_parameters": False}
        fn = functools.partial(self._test_transformer, config=config)
        spawn_and_init(fn)

    def test_nested_wrapper(self):
        fn = functools.partial(self._test_nested_wrapper, config={})
        spawn_and_init(fn)

    def test_no_sync_before_first_forward(self):
        group = DummyProcessGroup(rank=0, size=1)
        model = self.get_wrapped_model(group, config={}, add_bn=False)
        batch = model.module.get_input(torch.device("cuda"))
        with model.no_sync():
            output = model(*batch)
            loss = model.module.get_loss(batch, output)
            loss.backward()
        output = model(*batch)
        loss = model.module.get_loss(batch, output)
        loss.backward()

    @classmethod
    def _test_transformer(self, rank, group, config):
        model = self.get_wrapped_model(group, config=config, add_bn=False)
        model.eval()  # turn off dropout for the test
        self._test_no_sync(model, batch_dim=1)

    @classmethod
    def _test_nested_wrapper(self, rank, group, config):
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        self._test_no_sync(model, batch_dim=0)

    @classmethod
    def _test_no_sync(self, model, batch_dim):
        # Generate two input batches. We'll test that we get the same grads if
        # we train on them sequentially while accumulating grads (with no_sync)
        # vs. concatenating the batches and training in one go.
        batch1 = model.module.get_input(torch.device("cuda"))
        assert isinstance(batch1, tuple)
        batch2 = tuple(
            # This randomly permutes the values in a multi-dim tensor.
            x.view(-1)[torch.randperm(x.numel())].view_as(x)
            for x in batch1
        )
        for x, y in zip(batch1, batch2):
            assert not torch.all(x == y)

        # Concat the batches along batch dimension.
        concat_batch = tuple(torch.cat((x, y), dim=batch_dim) for (x, y) in zip(batch1, batch2))

        # Establish reference behavior on the concat batch.
        model.zero_grad()
        output = model(*concat_batch)
        ref_loss = model.module.get_loss(concat_batch, output)
        ref_loss.backward()
        ref_grads = [p.grad.detach().clone() for p in model.parameters()]

        # Test that we get the same results by accumulating grads.
        model.zero_grad()
        with model.no_sync():  # accumulate gradients from the first batch
            output = model(*batch1)
            loss1 = model.module.get_loss(batch1, output)
            loss1.backward()
        output = model(*batch2)
        loss2 = model.module.get_loss(batch2, output)
        loss2.backward()
        accumulated_loss = loss1 + loss2
        accumulated_grads = [p.grad.detach().clone() for p in model.parameters()]

        torch.testing.assert_allclose(ref_loss, accumulated_loss)
        assert objects_are_equal(ref_grads, accumulated_grads, raise_exception=True)


class TestNoSyncCommunication(DistributedTest):
    def test_communication(self):
        config = {"mixed_precision": True}
        fn = functools.partial(self._test_communication, config=config)
        spawn_and_init(fn)

    @classmethod
    def _test_communication(self, rank, group, config):
        if group.size() == 1:
            return
        model = self.get_wrapped_model(group, config=config)
        batch = model.module.get_input(torch.device("cuda"))
        with patch("torch.distributed.all_gather") as mock_all_gather:
            with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
                with model.no_sync():
                    output = model(*batch)
                    loss = model.module.get_loss(batch, output)
                    loss.backward()
                assert mock_all_gather.call_count == 1
                assert mock_reduce_scatter.call_count == 0

                output = model(*batch)
                loss = model.module.get_loss(batch, output)
                loss.backward()
                assert mock_all_gather.call_count == 1
                assert mock_reduce_scatter.call_count == 1


if __name__ == "__main__":
    unittest.main()
@@ -48,7 +48,11 @@ class TestMemory(DistributedTest):
         del state_dict
         mems.append(get_cuda_mem())

-        assert mems[4] == mems[0]
+        # Any value other than `==` indicates a memory leak. If mems[4] >
+        # mems[0], that indicates we're not cleaning up params properly in
+        # summon_full_params. If mems[4] < mems[0], that indicates there's a
+        # memory leak in _train_for_several_steps.
+        assert mems[4] == mems[0], f"memory leak detected, {mems[4]} != {mems[0]}"


 class TestPersistence(DistributedTest):
......