Unverified Commit a5594032 authored by Min Xu and committed by GitHub

[feat] save memory by using bucket buffer only in backward (#633)



* [feat] save memory by using bucket buffer only in backward

- this fixes bug #627
- added documentation to clarify the buffer's cost and the speed/memory
  tradeoff
- added setup/teardown calls so that the buffer is only allocated during
  the backward pass, freeing that memory for the forward pass and the
  optimizer step (e.g. for activations).
- added a unit test that asserts the memory usage is in the expected range.

Compared with DDP:

  1. buffer size scales with the number of FSDP instances, not with model size
  2. buffer is only allocated during the backward pass (see the sketch below)
  3. buffer is used for small tensors only, to reduce overhead
  4. overlapping of compute and reduction works very differently
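
Mechanically, item 2 comes down to resizing the bucket tensors' storage in place (see `Bucket.setup()` / `Bucket.teardown()` in the diff below). A minimal standalone sketch of that trick, not the FSDP code itself:

```python
import torch

# Stand-in for a bucket buffer, shaped (world_size, shard_size); values are illustrative.
buf = torch.zeros(4, 1024)

# teardown(): free the backing memory while keeping the tensor object and its shape.
buf.storage().resize_(0)
assert buf.storage().size() == 0

# setup(): re-allocate the storage right before the backward pass needs the buffer.
buf.storage().resize_(buf.size().numel())
assert buf.storage().size() == buf.numel()
```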

* add PR number to changelog

* filled in with memory number on 1.9

* addressed comments

* update comments

* fix for 1.6

* add a todo
Co-authored-by: Min Xu <min.xu@acm.org>
parent 9b79cc02
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## NEXT - TBD
### Added
- FSDP: better memory usage for reduce bucket ([#633](https://github.com/facebookresearch/fairscale/pull/633))
## [0.3.6] - 2021-04-26
### Added
...
@@ -155,10 +155,30 @@ class FullyShardedDataParallel(nn.Module):
*``cpu_offload``*.
bucket_cap_mb (int, Optional):
FSDP will bucket parameters so that gradient reduction can
be more efficient for small parameters.
``bucket_cap_mb`` controls the bucket size in MegaBytes (MB). Buckets
are sub-divided based on world_size, so the max shard size is roughly
``bucket_cap_mb / world_size``. There is one bucketer (with potentially
multiple ``bucket_cap_mb``-sized buffers) shared by all FSDP instances.
Large gradient tensors are directly reduced without using the buffers.
The buffers are there to reduce communication overhead for small tensors.
Overlapping with computation happens because the reduction uses a CUDA
stream that is different from the computation CUDA stream. The total memory
overhead per buffer is around ``bucket_cap_mb / world_size * (world_size + 1)``.
The buffers are allocated during the backward pass and freed at the end
of the backward pass to save more memory for other phases of the
training process.
Note that the memory vs. speed tradeoff of the bucket size is very different
from that of the DDP engine. In DDP, the buffer size is ``1MB + n*cap_mb``,
until n is big enough to cover the entire model size. The order in which
buffers become ready there is more rigid, and DDP requires all gradients
to be computed in the backward pass. In FSDP, the buffer size does not
change with model size (it changes based on the number of
<dtype, device, process_group> tuples), and the gradient ready order matters
little since FSDP has a final flush call that ensures everything is reduced;
not all gradients need to be known upfront. Overlapping with compute is
done differently too.
Values <= 0 disable bucketing.
Default: 25.
compute_device (torch.device, Optional):
device for computation. If not given and module params are on a CUDA
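
To make the per-buffer overhead formula in the ``bucket_cap_mb`` docstring above concrete, here is a small worked example; the world size of 8 is an assumption for illustration only:

```python
# Per-buffer overhead ~= bucket_cap_mb / world_size * (world_size + 1):
# each bucket holds a (world_size, shard_size) data tensor plus one output shard.
bucket_cap_mb = 25   # the documented default
world_size = 8       # assumed for this example

shard_mb = bucket_cap_mb / world_size      # ~3.1 MB per shard
overhead_mb = shard_mb * (world_size + 1)  # ~28.1 MB per <dtype, device, process_group> bucket
print(f"~{overhead_mb:.1f} MB per bucket")
```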
@@ -1226,15 +1246,6 @@ class FullyShardedDataParallel(nn.Module):
else:
self.assert_state(TrainingState.BACKWARD_PRE)
def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None:
"""Helper used below on all fsdp modules."""
for p in fsdp_module.params:
if p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
p._shard_bwd_hook[1].remove()
delattr(p, "_shard_bwd_hook")
if self._require_backward_grad_sync:
# Flush any unreduced buckets in the post_backward stream.
with torch.cuda.stream(self._streams["post_backward"]):
@@ -1244,7 +1255,23 @@ class FullyShardedDataParallel(nn.Module):
if self.move_grads_to_cpu:
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
torch.cuda.current_stream().synchronize()
# A backward pass is done, update root and nested FSDP's flags.
# A backward pass is done, clean up below.
# Free reducer buffers.
if self._reducer is not None:
self._reducer.teardown()
def _remove_shard_bwd_hook(fsdp_module: FullyShardedDataParallel) -> None:
"""Helper used below on all fsdp modules."""
for p in fsdp_module.params:
if p.requires_grad:
if hasattr(p, "_shard_bwd_hook"):
assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
p._shard_bwd_hook[1].remove()
delattr(p, "_shard_bwd_hook")
# Update root and nested FSDP's hooks and flags.
for m in self.modules():  # includes self
if isinstance(m, FullyShardedDataParallel):
_remove_shard_bwd_hook(m)
@@ -1739,6 +1766,8 @@ def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False, process_group:
# **must** be False because BN's FSDP wrapper's pre-backward callback isn't called
# within the checkpoint's outer backward when multiple forward passes are used.
"reshard_after_forward": False,
# No bucketing or small bucketing should be enough for BNs.
"bucket_cap_mb": 0,
}
with enable_wrap(wrap_bn_only_policy, **fsdp_config):
...
@@ -21,6 +21,7 @@ class Bucket:
self.output_shard = torch.zeros_like(data[0])
def flush(self) -> None:
"""Flush content of the bucket."""
if self.offset == 0:
assert len(self.callbacks) == 0
return
@@ -37,6 +38,24 @@ class Bucket:
self.callbacks.clear()
self.output_shard = torch.zeros_like(self.data[0])
def setup(self) -> None:
""" Setup the buffers if they are not allocated.
Using ``setup`` and ``teardown``, we can ensure that the bucket
buffers are only allocated during the backward pass, hence saving more
memory to other parts of the training process, such as the forward pass
for activation memory.
"""
for tensor in [self.data, self.output_shard]:
if tensor.storage().size() == 0:
tensor.storage().resize_(tensor.size().numel())
def teardown(self) -> None:
"""Tear down the bucket by freeing the memory"""
assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
for tensor in [self.data, self.output_shard]:
tensor.storage().resize_(0)
class ReduceScatterBucketer:
"""
@@ -132,6 +151,12 @@ class ReduceScatterBucketer:
for bucket in self.buckets.values():
bucket.flush()
@torch.no_grad()
def teardown(self) -> None:
"""Free buffers from all buckets."""
for bucket in self.buckets.values():
bucket.teardown()
@functools.lru_cache()
def _get_shard_size(self, element_size: int, num_shards: int) -> int:
if self.bucket_cap_mb <= 0:  # Values <= 0 disable bucketing.
@@ -141,6 +166,10 @@ class ReduceScatterBucketer:
return int(bucket_size // num_shards)
def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
# TODO (Min): the `group` used here in the key is the object hash, not the content
# hash. That means if FSDP instances are initialized with different process groups,
# even when the group members are in fact the same, we end up creating different
# buckets here.
key = (tensor.dtype, tensor.device, group)
if key not in self.buckets:
# buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
@@ -148,4 +177,5 @@ class ReduceScatterBucketer:
shard_size = self._get_shard_size(tensor.element_size(), world_size)
data = tensor.new_zeros((world_size, shard_size))
self.buckets[key] = Bucket(data, group)
self.buckets[key].setup()
return self.buckets[key]
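
The TODO in `_get_bucket` above can be illustrated with a short sketch (not part of the commit; it assumes `torch.distributed` is already initialized with at least two ranks): two process groups built from the same ranks are distinct objects, so they produce distinct bucket keys.

```python
import torch
import torch.distributed as dist

# Two groups with identical membership, created separately.
pg_a = dist.new_group(ranks=[0, 1])
pg_b = dist.new_group(ranks=[0, 1])

key_a = (torch.float32, torch.device("cuda"), pg_a)
key_b = (torch.float32, torch.device("cuda"), pg_b)
# The keys differ because process groups compare by object identity,
# so FSDP instances using pg_a and pg_b would get separate buckets.
assert key_a != key_b
```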
tests/nn/data_parallel/test_fsdp_memory.py
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
tests/nn/data_parallel/test_fsdp_freezing_weights.py
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with GPU memory usage. """
import gc
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim
from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version
def get_global_group():
"""
Singleton pytorch distributed group
Inspired by https://github.com/pytorch/fairseq
"""
if dist.is_initialized():
if not hasattr(get_global_group, "_global_group"):
get_global_group._global_group = dist.new_group()
return get_global_group._global_group
else:
return None
def to_fsdp(module):
return FSDP(module, process_group=get_global_group())
def dump_all_tensors(rank):
"""Use this for debugging"""
if rank != 0:
return
for obj in gc.get_objects():
try:
# Print anything that is a tensor or wraps a tensor (e.g. a parameter).
ttype = str(type(obj))
if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
print(ttype, obj.shape, obj.dtype, obj.device, id(obj), obj.storage().size())
except Exception:
pass
def get_cur_mem(rank, result, prefix):
"""Collect memory allocated values in a result dict in MB"""
result[prefix] = round(torch.cuda.memory_allocated() / 1024 / 1024)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.stem = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
self.blocks = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=5, padding=2),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=5, padding=2),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=5, padding=2),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
nn.Flatten(),
)
self.head = nn.Linear(128, 10)
def forward(self, x):
return self.head(self.blocks(self.stem(x)))
def create_model(with_fsdp, with_checkpoint):
model = Model()
if with_fsdp:
model.stem = auto_wrap_bn(model.stem, single_rank_pg=False)
model.blocks = auto_wrap_bn(model.blocks, single_rank_pg=False)
if with_checkpoint:
model.blocks = checkpoint_wrapper(model.blocks)
model.stem = to_fsdp(model.stem)
model.blocks = to_fsdp(model.blocks)
model.head = to_fsdp(model.head)
else:
if with_checkpoint:
model.blocks = checkpoint_wrapper(model.blocks)
return model
def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected):
torch.cuda.set_device(gpu_id)
rank = gpu_id
result = dist_init(rank, world_size, filename, filename_rpc)
assert result, "Dist init failed"
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
batch = torch.randn(size=(2, 3, 224, 224)).cuda()
model = create_model(with_fsdp, with_checkpoint)
model = model.cuda()
if with_fsdp:
model = to_fsdp(model)
else:
model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-4)
results = {}
for iteration in range(3):
get_cur_mem(gpu_id, results, f"iter {iteration}: start")
out = model(batch)
get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")
out = sum(o.sum() for o in out[0])
fake_loss = criterion(out, torch.tensor(0.0).cuda())
get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")
fake_loss.backward()
get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")
optimizer.step()
get_cur_mem(gpu_id, results, f"iter {iteration}: after step")
# It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory.
if torch_version() >= (1, 7, 0):
model.zero_grad(set_to_none=True)
else:
for p in model.parameters():
p.grad = None
get_cur_mem(gpu_id, results, f"iter {iteration}: done")
assert results == expected, f"{results} but expected {expected}"
teardown()
@skip_if_single_gpu
@pytest.mark.parametrize("ckpt", ["no_ckpt", "ckpt"])
@pytest.mark.parametrize("fsdp", ["ddp", "fsdp"])
def test_fsdp_memory(fsdp, ckpt):
expected = {
("ddp", "no_ckpt"): {
"iter 0: start": 9,
"iter 0: after fwd": 346,
"iter 0: after loss": 346,
"iter 0: after bwd": 14,
"iter 0: after step": 14,
"iter 0: done": 9,
"iter 1: start": 9,
"iter 1: after fwd": 346,
"iter 1: after loss": 346,
"iter 1: after bwd": 14,
"iter 1: after step": 14,
"iter 1: done": 9,
"iter 2: start": 9,
"iter 2: after fwd": 346,
"iter 2: after loss": 346,
"iter 2: after bwd": 14,
"iter 2: after step": 14,
"iter 2: done": 9,
},
("fsdp", "no_ckpt"): {
"iter 0: start": 3,
"iter 0: after fwd": 340,
"iter 0: after loss": 340,
"iter 0: after bwd": 66,
"iter 0: after step": 66,
"iter 0: done": 3,
"iter 1: start": 3,
"iter 1: after fwd": 340,
"iter 1: after loss": 340,
"iter 1: after bwd": 66,
"iter 1: after step": 66,
"iter 1: done": 3,
"iter 2: start": 3,
"iter 2: after fwd": 340,
"iter 2: after loss": 340,
"iter 2: after bwd": 66,
"iter 2: after step": 66,
"iter 2: done": 3,
},
("ddp", "ckpt"): {
"iter 0: start": 9,
"iter 0: after fwd": 57,
"iter 0: after loss": 57,
"iter 0: after bwd": 14,
"iter 0: after step": 14,
"iter 0: done": 9,
"iter 1: start": 9,
"iter 1: after fwd": 57,
"iter 1: after loss": 57,
"iter 1: after bwd": 14,
"iter 1: after step": 14,
"iter 1: done": 9,
"iter 2: start": 9,
"iter 2: after fwd": 57,
"iter 2: after loss": 57,
"iter 2: after bwd": 14,
"iter 2: after step": 14,
"iter 2: done": 9,
},
("fsdp", "ckpt"): {
"iter 0: start": 3,
"iter 0: after fwd": 51,
"iter 0: after loss": 51,
"iter 0: after bwd": 66,
"iter 0: after step": 66,
"iter 0: done": 3,
"iter 1: start": 3,
"iter 1: after fwd": 51,
"iter 1: after loss": 51,
"iter 1: after bwd": 66,
"iter 1: after step": 66,
"iter 1: done": 3,
"iter 2: start": 3,
"iter 2: after fwd": 51,
"iter 2: after loss": 51,
"iter 2: after bwd": 66,
"iter 2: after step": 66,
"iter 2: done": 3,
},
}[(fsdp, ckpt)]
fsdp = fsdp == "fsdp"
ckpt = ckpt == "ckpt"
world_size = 2
with temp_files_ctx(num=2) as temp_files:
mp.spawn(
_distributed_worker, (world_size, fsdp, ckpt, temp_files[0], temp_files[1], expected), nprocs=world_size
)