Unverified commit 81c20f72 authored by Quentin Duval, committed by GitHub

[feat] Save FSDP metadata for offline unflattening + Consolidate checkpoints (#683)



* Save FSDP metadata for offline unflattening

* Complete the metadata saving method with all the information needed to reconstruct a checkpoint offline, and implement the method that reconstructs a consolidated checkpoint from a sharded checkpoint

* Add a unit test to show how to use the function

* Code review + improvement of the unit tests

* Code review: extract clean_path

* Make metadata and checkpoint consolidation work for flatten_parameters=False

* Add new unit test file in CI

* Complete changelog and fix mypy issues

* Add support for module buffers in the consolidation of sharded checkpoints

* Better support for module buffers: save them in the metadata

* Refactoring: use a data format for the metadata that is simpler to understand (move from an object-of-arrays to an array-of-objects format)

* Renaming to make code clearer

* Code review: in_temporary_directory rework and typo correction

* Renaming
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Co-authored-by: QuentinDuval <QuentinDuval@users.noreply.github.com>
parent d240b748
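For orientation, a minimal sketch of the workflow this change enables, mirroring the unit test added below. `build_model`, `rank`, and `world_size` are hypothetical placeholders; `local_state_dict`, `local_metadata_dict`, and `consolidate_shard_weights` are the FSDP APIs involved:

import torch
from fairscale.nn import FullyShardedDataParallel as FSDP

# On each rank: save the local (sharded) weights together with the sharding metadata.
model = FSDP(build_model().cuda())  # build_model is a placeholder
checkpoint = {
    "weights": {k: v.cpu() for k, v in model.local_state_dict().items()},
    "meta": model.local_metadata_dict(),
}
torch.save(checkpoint, f"checkpoint_{rank}.torch")

# Offline, without re-instantiating FSDP at the original world size:
shards = [torch.load(f"checkpoint_{r}.torch") for r in range(world_size)]
consolidated = FSDP.consolidate_shard_weights(
    shard_weights=[s["weights"] for s in shards],
    shard_metadata=[s["meta"] for s in shards],
)
torch.save(consolidated, "consolidated.torch")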
......@@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- FSDP: added `force_input_to_fp32` flag for SyncBatchNorm [#659]
- FSDP: better memory usage for reduce bucket [#633]
- Experimental SyncBatchNorm [#662]
- FSDP: added `local_metadata_dict` to save sharding-related information [#683]
- FSDP: added `consolidate_shard_weights` to reconstruct the consolidated (non-sharded) model weights from sharded weights and metadata saved on disk [#683]
## [0.3.6] - 2021-04-26
### Added
......
......@@ -285,7 +285,12 @@ class FullyShardedDataParallel(nn.Module):
# Only handle params which are not already sharded. This enables
# sharding individual layers of a Module, with an outer wrapper to
# shard any leftover parameters.
params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))
param_names = []
params = []
for param_name, param in module.named_parameters():
if not hasattr(param, "_is_sharded"):
param_names.append(param_name)
params.append(param)
self._has_params = len(params) > 0
if not self._has_params:
......@@ -294,9 +299,11 @@ class FullyShardedDataParallel(nn.Module):
if self.flatten_parameters:
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
del module # free original module in case it helps garbage collection
self.param_paths = ["flat_param"]
self.params = [self._fsdp_wrapped_module.flat_param]
else:
self._fsdp_wrapped_module = module
self.param_paths = param_names
self.params = params
# Shard module parameters in place
......@@ -1509,6 +1516,141 @@ class FullyShardedDataParallel(nn.Module):
# the Storage to 0 to save memory.
free_storage_(p._full_param_padded)
def local_metadata_dict(self) -> Dict[str, Any]:
"""
Get the information needed to reconstruct the model from shards offline.
"""
params_metadata = []
for path, m in self.named_modules():
if not isinstance(m, FullyShardedDataParallel):
continue
# Dealing with FSDP(flatten_parameters=False)
# There are as many sharded parameters as there are parameters in the
# consolidated model, so we only need to export how to reshape the
# parameters to their original shape and take care of the padding
if not hasattr(m, "_param_numels"):
params_metadata.append(
{
"fsdp_path": _clean_path(path),
"is_flat": False,
"num_padded": m.numel_padded_per_param,
"param_names": [_clean_path(p) for p in m.param_paths],
"param_shapes": [p._orig_size for p in m.params],
"param_numels": [_numel_from_size(p._orig_size) for p in m.params],
"no_broadcast_optim_state": m.no_broadcast_optim_state,
}
)
# Dealing with FSDP(flatten_parameters=True)
# Here, there is just one flattened parameter mapped to N different
# parameters, so we need to export additional information (numels)
# on how to split the merged flat parameter, by extracting the metadata
# used in the FlattenParamsWrapper
else:
param_names = []
for module_path, param_name in m._param_full_infos:
full_param_path = module_path + "." + param_name if module_path else param_name
param_names.append(_clean_path(full_param_path))
params_metadata.append(
{
"fsdp_path": _clean_path(path),
"is_flat": True,
"num_padded": m.numel_padded_per_param,
"param_names": param_names,
"param_shapes": m._param_shapes,
"param_numels": m._param_numels,
"no_broadcast_optim_state": m.no_broadcast_optim_state,
}
)
buffer_names = [_clean_path(buffer_name) for buffer_name, _ in self.named_buffers(recurse=True)]
return dict(param_metadata=params_metadata, buffer_names=buffer_names)
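For illustration, a hedged sketch of the structure returned by `local_metadata_dict` for a single wrapper with flatten_parameters=True (assuming `import torch`; all values are invented, only the keys follow the code above):

# Illustrative only: actual values depend on the wrapped model and the world size.
example_metadata = {
    "param_metadata": [
        {
            "fsdp_path": "",                  # root wrapper
            "is_flat": True,                  # flatten_parameters=True
            "num_padded": [3],                # padding added to this rank's shard
            "param_names": ["fc1.weight", "fc1.bias"],
            "param_shapes": [torch.Size([128, 64]), torch.Size([128])],
            "param_numels": [8192, 128],
            "no_broadcast_optim_state": False,
        },
    ],
    "buffer_names": ["bn1.running_mean", "bn1.running_var", "bn1.num_batches_tracked"],
}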
@staticmethod
def consolidate_shard_weights(
shard_weights: List[Dict[str, torch.Tensor]],
shard_metadata: List[Dict[str, Any]],
with_module_buffers: bool = True,
) -> Dict[str, torch.Tensor]:
"""
Given a list of weights and metadata associated with N shards, reconstruct
the weights of an equivalent consolidated (non-sharded) model.
Module parameters are consolidated using the shard metadata.
Module buffers are taken from shard 0: this assumes that module buffers
are either synchronized or that the shard 0 value is valid for all shards.
If this behavior is not correct for your module (for instance if buffers
need to be reduced instead), you can disable it with `with_module_buffers=False`.
This method is useful for re-assembling checkpoints of shards without
having to instantiate FSDP wrappers with the world size originally used
to save the shards.
"""
if len(shard_weights) != len(shard_metadata) or not len(shard_weights):
raise ValueError("Require meta data for each shard and non-empty shards")
consolidated_weights = {}
original_world_size = len(shard_weights)
# Deal with the parameters of the model, for which there should be
# a corresponding entry in the metadata
shard_0_metadata = shard_metadata[0]["param_metadata"]
num_fsdp_wrappers = len(shard_0_metadata)
for fsdp_wrapper_index in range(num_fsdp_wrappers):
fsdp_path = shard_0_metadata[fsdp_wrapper_index]["fsdp_path"]
param_names = shard_0_metadata[fsdp_wrapper_index]["param_names"]
param_numels = shard_0_metadata[fsdp_wrapper_index]["param_numels"]
param_shapes = shard_0_metadata[fsdp_wrapper_index]["param_shapes"]
# Dealing with FSDP(flatten_parameters=False)
# For each parameter of the FSDP wrapper, remove the padding from each shard,
# concatenate the shards, and reshape the result to the parameter's initial shape
if not shard_0_metadata[fsdp_wrapper_index]["is_flat"]:
for i in range(len(param_names)):
param_name = param_names[i]
param_name = ".".join([fsdp_path, param_name]) if fsdp_path else param_name
shards = []
for rank in range(original_world_size):
shard = shard_weights[rank][param_name]
pad = shard_metadata[rank]["param_metadata"][fsdp_wrapper_index]["num_padded"][i]
shards.append(_unpad(shard, pad))
full_flatten_param = torch.cat(shards, dim=0)
consolidated_weights[param_name] = full_flatten_param.view(param_shapes[i])
# Dealing with FSDP(flatten_parameters=True)
# Remove the padding from each shard of the flat_param, concatenate the
# shards, then split the result by numel and reshape each split to its
# original shape
else:
# Concatenate the flat_param shards after removing the padding
flat_param_name = ".".join([fsdp_path, "flat_param"]) if fsdp_path else "flat_param"
shards = []
for rank in range(original_world_size):
shard = shard_weights[rank][flat_param_name]
pad = shard_metadata[rank]["param_metadata"][fsdp_wrapper_index]["num_padded"][0]
shards.append(_unpad(shard, pad))
full_flatten_param = torch.cat(shards, dim=0)
# Split the flat_param into its constituents
assert sum(param_numels) == full_flatten_param.size(0)
for n, t, s in zip(param_names, full_flatten_param.split(param_numels), param_shapes):
full_name = fsdp_path + "." + n if fsdp_path else n
consolidated_weights[full_name] = t.view(s)
# Deal with the buffers, which are not parameters and are not sharded by FSDP
# and therefore are replicated among the different shards.
# We take the values of the first shard (this assumes that there is some form
# of synchronization between shards or that all shard buffers are equivalent)
if with_module_buffers:
for buffer_name in shard_metadata[0]["buffer_names"]:
consolidated_weights[buffer_name] = shard_weights[0][buffer_name]
return consolidated_weights
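As a sanity check on the unpad / concatenate / split / reshape arithmetic used in the flattened branch above, a small self-contained sketch with toy sizes (standalone tensors, not real FSDP shards):

import torch

# Two ranks hold halves of a flat_param covering a (3, 4) weight and a (4,) bias:
# 12 + 4 = 16 elements, padded to 18 so that each rank stores 9 elements.
param_numels = [12, 4]
param_shapes = [torch.Size([3, 4]), torch.Size([4])]
flat = torch.arange(16, dtype=torch.float32)
padded = torch.cat([flat, torch.zeros(2)])       # 2 elements of padding
shards = [padded[:9], padded[9:]]                # rank 0 and rank 1 shards
num_padded = [0, 2]                              # only the last rank is padded

unpadded = [s[:-p] if p > 0 else s for s, p in zip(shards, num_padded)]
full_flat_param = torch.cat(unpadded, dim=0)
assert full_flat_param.numel() == sum(param_numels)
weight, bias = (t.view(s) for t, s in zip(full_flat_param.split(param_numels), param_shapes))
assert weight.shape == (3, 4) and bias.shape == (4,)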
@torch.no_grad()
def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Use FP32 shard for a list of params."""
......@@ -1818,6 +1960,23 @@ def _pre_load_state_dict_hook(
replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
def _clean_path(path: str) -> str:
return ".".join([split for split in path.split(".") if split not in {"_fsdp_wrapped_module", "_fpw_module"}])
def _numel_from_size(size: torch.Size) -> int:
numel = 1
for dim in size:
numel *= dim
return numel
def _unpad(shard: torch.Tensor, pad: int) -> torch.Tensor:
if pad > 0:
shard = shard[:-pad]
return shard
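For reference, a doctest-style example of `_clean_path` on an invented parameter path: the FSDP and FlattenParamsWrapper wrapper names are stripped so that consolidated keys match the original, unwrapped model.

>>> _clean_path("conv2._fsdp_wrapped_module._fpw_module.0.weight")
'conv2.0.weight'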
########################################################################################
# Below are APIs used together with FSDP, but not directly part of FSDP.
########################################################################################
......
......@@ -70,12 +70,13 @@ class FlattenParamsWrapper(nn.Module):
def _init_flatten_params(self) -> List[Tensor]:
param_infos = []
param_full_infos = []
shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
shared_param_infos = []
params = []
param_numels = []
param_shapes = []
for m in self.modules():
for module_name, m in self.named_modules():
for n, p in m.named_parameters(recurse=False):
if p is not None and (m, n) in self._param_set:
if p in shared_param_memo:
......@@ -84,6 +85,7 @@ class FlattenParamsWrapper(nn.Module):
else:
shared_param_memo[p] = (m, n)
param_infos.append((m, n))
param_full_infos.append((module_name, n))
params.append(p.detach())
param_numels.append(p.numel())
param_shapes.append(p.size())
......@@ -93,6 +95,7 @@ class FlattenParamsWrapper(nn.Module):
# store the info for unflatten
self._param_infos = tuple(param_infos)
self._param_full_infos = tuple(param_full_infos)
self._shared_param_infos = tuple(shared_param_infos)
self._param_numels = tuple(param_numels)
self._param_shapes = tuple(param_shapes)
......
......@@ -665,6 +665,19 @@ def rmf(filename: str) -> None:
pass
@contextlib.contextmanager
def in_temporary_directory() -> Generator:
"""
Context manager that creates a temporary directory, switches the current
working directory to it, and removes it at the end of the context
"""
old_cwd = os.getcwd()
with tempfile.TemporaryDirectory() as temp_dir:
os.chdir(temp_dir)
yield temp_dir
os.chdir(old_cwd)
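A brief hedged usage sketch of this helper (the file name is arbitrary):

import os
from fairscale.utils.testing import in_temporary_directory

cwd_before = os.getcwd()
with in_temporary_directory():
    # Relative paths now resolve inside a throwaway directory.
    with open("checkpoint_0.torch", "wb") as f:
        f.write(b"placeholder")
    assert os.path.exists("checkpoint_0.torch")
assert os.getcwd() == cwd_before  # the previous working directory is restored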
@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
""" A context to get tempfiles and ensure they are cleaned up. """
......@@ -686,7 +699,7 @@ def dump_all_tensors(rank: int) -> None:
ttype = str(type(obj))
if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
print(ttype, obj.shape, obj.dtype, obj.device, obj.storage().size())
except Exception as e:
except Exception:
pass
print(torch.cuda.memory_summary())
......
......@@ -2,6 +2,7 @@ tests/nn/data_parallel/test_fsdp_overlap.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_apply.py
tests/nn/data_parallel/test_fsdp_state_dict.py
tests/nn/data_parallel/test_fsdp_metadata.py
tests/utils/test_reduce_scatter_bucketer.py
tests/utils/test_containers.py
tests/utils/test_parallel.py
......
......@@ -15,6 +15,7 @@ from unittest import mock
from parameterized import parameterized
import torch
from torch import nn
import torch.distributed
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
......@@ -119,6 +120,9 @@ class DistributedTest(unittest.TestCase):
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
if config.get("flatten_parameters", True):
metadata = model.local_metadata_dict()
assert isinstance(metadata, dict)
class TestMixedPrecision(DistributedTest):
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel
from fairscale.utils.testing import in_temporary_directory, skip_if_single_gpu, temp_files_ctx
class ConvolutionalModel(nn.Module):
def __init__(self, embedding_size: int, with_fsdp: bool, process_group):
super().__init__()
self.conv1 = self._conv_block(3, embedding_size)
self.conv2: nn.Module = self._conv_block(embedding_size, embedding_size // 2)
self.conv3: nn.Module = self._conv_block(embedding_size // 2, embedding_size)
self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
self.flatten = nn.Flatten(start_dim=1)
self.relu = nn.ReLU()
self.fc1: nn.Module = nn.Linear(embedding_size, 2 * embedding_size)
self.fc2: nn.Module = nn.Linear(2 * embedding_size, 2 * embedding_size)
self.fc3: nn.Module = nn.Linear(2 * embedding_size, embedding_size + 1)
self.fc4: nn.Module = nn.Linear(embedding_size + 1, embedding_size)
if with_fsdp:
self.conv2 = FullyShardedDataParallel(self.conv2, process_group=process_group)
self.conv3 = FullyShardedDataParallel(self.conv3, process_group=process_group, flatten_parameters=False)
self.fc1 = FullyShardedDataParallel(self.fc1, process_group=process_group)
self.fc3 = FullyShardedDataParallel(self.fc3, process_group=process_group, flatten_parameters=False)
@staticmethod
def _conv_block(in_channels: int, out_channels: int):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3)), nn.BatchNorm2d(out_channels), nn.ReLU(),
)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.pool(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
x = self.relu(x)
x = self.fc4(x)
return x
def _create_model(embedding_size: int, with_fsdp: bool, process_group, flatten_parameters: bool = True):
model = ConvolutionalModel(with_fsdp=with_fsdp, process_group=process_group, embedding_size=embedding_size).cuda()
if with_fsdp:
return FullyShardedDataParallel(model, process_group=process_group, flatten_parameters=flatten_parameters)
else:
return model
def _load_sharded_checkpoint(rank: int):
return torch.load(f"checkpoint_{rank}.torch") # type: ignore
def _worker(gpu_id: int, sync_file: str, world_size: int, embedding_size: int, flatten_parameters: bool):
torch.manual_seed(0)
torch.cuda.set_device(gpu_id)
torch.distributed.init_process_group(
backend="nccl", init_method=f"file://{sync_file}", world_size=world_size, rank=gpu_id,
)
process_group = torch.distributed.new_group()
# Create a dummy model with dummy inputs and targets
batch_size = 4
input = torch.randn(size=(batch_size, 3, 32, 32)).cuda()
target = torch.zeros(size=(batch_size, embedding_size)).cuda()
model = _create_model(
with_fsdp=True,
process_group=process_group,
embedding_size=embedding_size,
flatten_parameters=flatten_parameters,
)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
# Train the model for a few epochs
for epoch in range(2):
out = model(input)
loss = criterion(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Save one checkpoint per shard (one file per rank)
cp_data = {
"weights": {k: v.cpu() for k, v in model.local_state_dict().items()},
"meta": model.local_metadata_dict(),
}
torch.save(cp_data, f"checkpoint_{gpu_id}.torch")
# Wait for all shard files to be written to disk
dist.barrier() # type: ignore
# Reconstruct a full checkpoint from the sharded checkpoints
all_checkpoints = [_load_sharded_checkpoint(rank) for rank in range(world_size)]
consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
shard_weights=[c["weights"] for c in all_checkpoints], shard_metadata=[c["meta"] for c in all_checkpoints],
)
# Check that the reconstructed parameters are correct and of the right shape
full_model = _create_model(with_fsdp=False, process_group=process_group, embedding_size=embedding_size)
full_model_state_dict = full_model.state_dict()
assert set(full_model_state_dict.keys()) == set(consolidated_checkpoint.keys())
for k in full_model_state_dict.keys():
assert consolidated_checkpoint[k].shape == full_model_state_dict[k].shape
# Verify that the checkpoint can be loaded by an FSDP model
loaded_model = _create_model(
with_fsdp=True,
process_group=process_group,
embedding_size=embedding_size,
flatten_parameters=flatten_parameters,
)
loaded_model.load_state_dict(consolidated_checkpoint)
for m in loaded_model.modules():
if isinstance(m, FullyShardedDataParallel):
m._reset_lazy_init()
# Verify that the saved model and the loaded model give the same results
with torch.no_grad():
before_checkpoint_loss = criterion(model(input), target).item()
after_checkpoint_loss = criterion(loaded_model(input), target).item()
assert before_checkpoint_loss == after_checkpoint_loss
@skip_if_single_gpu
@pytest.mark.parametrize("embedding_size", [128, 129])
@pytest.mark.parametrize("flatten_parameters", [True, False])
def test_consolidation(embedding_size: int, flatten_parameters: bool):
import torch.multiprocessing as mp
world_size = 2
with in_temporary_directory():
with temp_files_ctx(num=1) as temp_files:
mp.spawn(_worker, (temp_files[0], world_size, embedding_size, flatten_parameters), nprocs=world_size)