Unverified Commit 83b0b49e authored by Min Xu, committed by GitHub

[feat]: prepare FSDP to handle multiple flatten params and fixed metadata saving for MoE (#746)



* [feat] FSDP: supporting multiple flatten parameter groups

- step 3: make FSDP use FlattenParamModule unconditionally

* fixing the auto_wrap tests

* minor

* rewrite local_metadata_dict

- updated FPW so that custom flat param name is also supported

* bug fix

* mypy

* rewrote consolidate_shard_weights

- test_consolidate passes

* comments

* fixing pickling

* Fix shared params and MoE logic (#749)

* add strict kwarg to support fairseq:gshard MoE saving logic

* Test fairseq style shard

* style

* formatting and address comments

* added changelog

* fixing a test after padding renaming
Co-authored-by: Min Xu <min.xu.public@gmail.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 8bca4f87
@@ -6,8 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
+- FSDP: fixed metadata saving and shard consolidation for MoE cases [#746]
 ### Added
+- FSDP: better performance; use `_allgather_base` and `_reduce_scatter_base` when available [#729]
+- FSDP: prepared FSDP internals for supporting multiple groups of flatten parameters (to support more general optimization) [#746]
 ## [0.3.8] - 2021-07-12
 ### Fixed
...
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 """These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
 import copy
-from typing import Any, Dict, Generator, List, Tuple, cast
+from typing import Any, Dict, Iterator, List, Tuple, cast

 import torch
@@ -130,7 +130,7 @@ def _unflatten_optim_state(
             v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
             flat_buffer = torch.cat(v_unpad)
             # Casting needed only for mypy.
-            param_views: Generator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views([flat_buffer])
+            param_views: Iterator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views([flat_buffer])
             for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
                 assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
                 unflat_state[global_id][k] = param_view
...
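The unpad-then-view pattern that `_unflatten_optim_state` relies on above can be illustrated in isolation. A minimal sketch, with invented shard sizes, padding, and shapes:

```python
import torch

# Two per-rank flat shards of one optimizer-state value; rank 1 carries 2 elements of padding.
shards = [torch.arange(6.0), torch.arange(6.0)]
pad_info = [0, 2]

# Strip the padding, concatenate into the full flat buffer, then view it as the original params.
unpadded = [t[:-p] if p > 0 else t for t, p in zip(shards, pad_info)]
flat_buffer = torch.cat(unpadded)  # 10 elements in total
views = [t.view(s) for t, s in zip(flat_buffer.split([4, 6]), [(2, 2), (2, 3)])]
assert [v.shape for v in views] == [torch.Size([2, 2]), torch.Size([2, 3])]
```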
@@ -12,7 +12,21 @@ from math import inf
 import time
 import traceback
 import typing
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Mapping, NamedTuple, Optional, Set, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Mapping,
+    NamedTuple,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    cast,
+)

 import torch
 from torch.autograd import Variable
@@ -293,18 +307,31 @@ class FullyShardedDataParallel(nn.Module):
                 params.append(param)

         self._has_params = len(params) > 0
-        if not self._has_params:
-            self.flatten_parameters = False

+        # For now, it is either all flatten or none flatten. This will be extended to
+        # multiple flatten groups in my next PR.
+        to_be_flatten_params: List[List[Parameter]] = [[]]
+        non_flatten_params = params
+        param_name_groups = [[n] for n in param_names]
         if self.flatten_parameters:
-            self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
-            del module  # free original module in case it helps garbage collection
-            self.param_paths = ["flat_param"]
-            self.params = [self._fsdp_wrapped_module.flat_param]
-        else:
-            self._fsdp_wrapped_module = module
-            self.param_paths = param_names
-            self.params = params
+            to_be_flatten_params = [params]
+            non_flatten_params = []
+            param_name_groups = [param_names]
+        del param_names
+
+        self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=to_be_flatten_params)
+        del module  # free original module in case it helps garbage collection
+
+        # Now, in this FSDP wrapper class, we keep a list of to-be-flatten and not-to-be-flatten
+        # params for doing sharding, gradient hooks, etc. Note, the ordering of the
+        # list matters: flatten params are always in the front.
+        #
+        # The self._num_flatten_params and self._param_name_groups are computed
+        # and kept here to support summon_full_params and shard-to-full weight
+        # consolidation.
+        self.params = cast(List[Parameter], self._fsdp_wrapped_module.flat_params) + non_flatten_params
+        self._num_flatten_params = len(self._fsdp_wrapped_module.flat_params)
+        self._param_name_groups = param_name_groups

         # Shard module parameters in place
         self._shard_parameters_()
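A minimal sketch of the ordering convention introduced above, assuming `fsdp` stands for an already constructed FSDP instance (names other than `params` and `_num_flatten_params` are not part of the diff):

```python
# `fsdp.params` keeps FlatParameter objects in front and plain nn.Parameter objects behind them.
flat_part = fsdp.params[: fsdp._num_flatten_params]
plain_part = fsdp.params[fsdp._num_flatten_params :]
assert len(flat_part) + len(plain_part) == len(fsdp.params)
# With flatten_parameters=True everything lands in flat_part; with False, everything is in plain_part.
```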
@@ -367,8 +394,10 @@ class FullyShardedDataParallel(nn.Module):
         self.gradient_postdivide_factor = post

     @property
-    def module(self) -> nn.Module:
-        return self._fsdp_wrapped_module  # note: may be a FlattenParamsWrapper instance
+    def module(self) -> FlattenParamsWrapper:
+        """ make model.module accessible, just like DDP. """
+        assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper)
+        return self._fsdp_wrapped_module

     def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
         """
@@ -629,6 +658,10 @@ class FullyShardedDataParallel(nn.Module):
         del self.orig_sizes
         self._reset_lazy_init()

+    def __getitem__(self, key: int) -> Any:
+        """Forward indexing calls in case the module is a nn.Sequential."""
+        return self.module.__getitem__(key)
+
     @typing.overload
     def state_dict(
         self, destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...
@@ -668,11 +701,7 @@ class FullyShardedDataParallel(nn.Module):
             state_dict = super().state_dict(*args, **kwargs)
         else:
             maybe_cast_buffers(torch.float32)
-            if self.flatten_parameters:
-                assert isinstance(self.module, FlattenParamsWrapper)
-                state_dict = self.module.flat_state_dict(*args, **kwargs)
-            else:
-                state_dict = super().state_dict(*args, **kwargs)
+            state_dict = self.module.flat_state_dict(*args, **kwargs)

         if self.move_params_to_cpu:
             for k in state_dict.keys():
@@ -827,13 +856,15 @@ class FullyShardedDataParallel(nn.Module):
         full_tensors = self._rebuild_full_params(force_full_precision=True)
         assert full_tensors is not None
         with contextlib.ExitStack() as stack:
-            if self.flatten_parameters and self.module.is_flattened:
+            if self.module.is_flattened:
                 # Update flattened views to point to fully-sized tensors. We
-                # use self.params[0] instead of full_tensors since the
+                # use self.params instead of full_tensors since the
                 # latter may contain padding.
-                assert len(self.params) == 1
-                assert isinstance(self.module, FlattenParamsWrapper)
-                stack.enter_context(self.module.unflatten_params(flat_params=[self.params[0]]))
+                stack.enter_context(
+                    self.module.unflatten_params(
+                        flat_params=[p.data for p in self.params[: self._num_flatten_params]]
+                    )
+                )
             try:
                 yield
             finally:
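For orientation, a hedged usage sketch of the `summon_full_params` context manager whose body the hunk above modifies; `fsdp_model` stands for an existing FSDP-wrapped model and the submodule name is made up:

```python
# Temporarily gather the full, unsharded weights on this rank, e.g. for inspection.
with fsdp_model.summon_full_params():
    # Inside the context, the original (unflattened) parameter views are available again.
    weight_copy = fsdp_model.module.fc1.weight.detach().clone()  # "fc1" is a hypothetical submodule
```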
@@ -1529,137 +1560,126 @@ class FullyShardedDataParallel(nn.Module):
     def local_metadata_dict(self) -> Dict[str, Any]:
         """
         Get the information needed to reconstruct the model from shards offline.
+
+        See the `consolidate_shard_weights` method below.
         """
-        params_metadata = []
+        param_metadata = []
         for path, m in self.named_modules():
-            if not isinstance(m, FullyShardedDataParallel):
-                continue
-
-            # Dealing with FSDP(flatten_parameter=False)
-            # There are as many sharded parameters as there parameters in the
-            # consolidated model, so we only need to export how to reshape the
-            # parameters to their orginal shape and take care of the padding
-            if not m.flatten_parameters:
-                params_metadata.append(
-                    {
-                        "fsdp_path": _clean_path(path),
-                        "is_flat": False,
-                        "num_padded": m.numel_padded_per_param,
-                        "param_names": [_clean_path(p) for p in m.param_paths],
-                        "param_shapes": [p._orig_size for p in m.params],
-                        "param_numels": [_numel_from_size(p._orig_size) for p in m.params],
-                        "no_broadcast_optim_state": m.no_broadcast_optim_state,
-                    }
-                )
-
-            # Dealing with FSDP(flatten_parameter=True)
-            # Now, there is just one flattened parameter mapped to N different
-            # parameters, so we need to export additional information (numels)
-            # on how to split the "merged" parameters, by extracting the meta-data
-            # used in the FlattenParamsWrapper
-            else:
-                param_names = []
-                for module_path, param_name in m.param_path_infos:
-                    full_param_path = module_path + "." + param_name if module_path else param_name
-                    param_names.append(_clean_path(full_param_path))
-                params_metadata.append(
-                    {
-                        "fsdp_path": _clean_path(path),
-                        "is_flat": True,
-                        "num_padded": m.numel_padded_per_param,
-                        "param_names": param_names,
-                        # TODO (Min): we don't want to access the private _param_shapes and
-                        # _param_numels here. We want to dump metadata from FPW when there are
-                        # multiple groups of params.
-                        "param_shapes": m._fsdp_wrapped_module.flat_param._param_shapes,  # type: ignore
-                        "param_numels": m._fsdp_wrapped_module.flat_param._param_numels,  # type: ignore
-                        "no_broadcast_optim_state": m.no_broadcast_optim_state,
-                    }
-                )
+            if isinstance(m, FullyShardedDataParallel):
+                metadata: Dict[str, Any] = {}
+                metadata["fsdp_path"] = _clean_path(path)
+                metadata["params"] = {}
+                metadata["no_broadcast_optim_state"] = m.no_broadcast_optim_state
+
+                shared_param_info = []
+                for (mpath_dst, mpath_src, _, src_name, _, dst_name) in m._shared_param_infos:
+                    src_param_path = _clean_path(mpath_src + "." + src_name if mpath_src else src_name)
+                    dst_param_path = _clean_path(mpath_dst + "." + dst_name if mpath_dst else dst_name)
+                    shared_param_info.append((src_param_path, dst_param_path))
+                metadata["shared_param_info"] = shared_param_info
+
+                for i, p in enumerate(m.params):
+                    if i < m._num_flatten_params:
+                        backing_param_name = m.module.flat_param_names[i]
+                        names, shapes, numels = m.module.metadata(i)
+                    else:
+                        assert len(m._param_name_groups[i]) == 1
+                        backing_param_name = m._param_name_groups[i][0]
+                        names = [backing_param_name]
+                        shapes = [p._orig_size]
+                        numels = [p._orig_size.numel()]
+                    backing_param_name = _clean_path(backing_param_name)
+                    metadata["params"][backing_param_name] = {
+                        "names": [_clean_path(n) for n in names],  # A list of str.
+                        "shapes": shapes,  # A list of torch.Size.
+                        "numels": numels,  # A list of int.
+                        "padding": m.numel_padded_per_param[i],  # An int for padding added to the backing parameter.
+                    }
+                param_metadata.append(metadata)

         buffer_names = [_clean_path(buffer_name) for buffer_name, _ in self.named_buffers(recurse=True)]
-        return dict(param_metadata=params_metadata, buffer_names=buffer_names)
+        return dict(param_metadata=param_metadata, buffer_names=buffer_names)
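For reference, a sketch of the dictionary shape produced by the new `local_metadata_dict`; the concrete names, shapes, and numels below are invented:

```python
import torch

example_metadata = {
    "param_metadata": [
        {
            "fsdp_path": "",  # cleaned path of this FSDP instance inside the model
            "params": {
                "flat_param_0": {  # backing (possibly flattened) parameter name
                    "names": ["fc1.weight", "fc1.bias"],
                    "shapes": [torch.Size([4, 4]), torch.Size([4])],
                    "numels": [16, 4],
                    "padding": 0,
                },
            },
            "no_broadcast_optim_state": False,
            "shared_param_info": [],  # (src_param_path, dst_param_path) pairs
        },
    ],
    "buffer_names": ["bn.running_mean"],
}
```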
     @staticmethod
     def consolidate_shard_weights(
         shard_weights: List[Dict[str, torch.Tensor]],
         shard_metadata: List[Dict[str, Any]],
         with_module_buffers: bool = True,
+        strict: bool = True,
     ) -> Dict[str, torch.Tensor]:
         """
         Given a list of weights and meta data associated to N shards, reconstruct
-        the weights of an equivalent consolidated (non-sharded) model.
+        the weights of an equivalent consolidated (non-sharded) state dict.

         Module parameters are consolidated using the shard metadata.

         Module buffers are taken from shard 0: this assumes that module buffers
         are either synchronized or that the shard 0 value is valid for all shards.
         If this behavior is not correct for your module (for instance if buffers
-        needs to be reduced instead), you can disable it with `with_module_buffers=False`.
+        needs to be all-reduced instead), you can disable it with `with_module_buffers=False`.

         This method is used to re-assemble checkpoints of shards without
-        having to instantiate FSDP wrappers with the world size originally used
-        to save the shards.
+        having to instantiate FSDP wrappers with the world size (i.e. large
+        number of GPUs) originally used to save the shards.
+
+        Args:
+            shard_weights (List[Dict[str, torch.Tensor]]):
+                List of dictionaries that contains sharded weights from
+                each rank.
+            shard_metadata (List[Dict[str, Any]]):
+                List of dictionaries that contains metadata from each shard.
+                See `local_metadata_dict` above.
+            with_module_buffers (bool):
+                If shard 0's buffer should be returned in the consolidated
+                weight dict.
+                Default: True.
+            strict (bool):
+                allow incomplete shard weights. if True, every key in the metadata must be present in the weights.
         """
         if len(shard_weights) != len(shard_metadata) or not len(shard_weights):
-            raise ValueError("Require meta data for each shard and non-empty shards")
+            raise ValueError("Require metadata for each shard and non-empty shards")

         consolidated_weights = {}
         original_world_size = len(shard_weights)

-        # Deal with the parameters of the model, for which there should be
-        # a corresponding entry in the metadata
-        shard_0_metadata = shard_metadata[0]["param_metadata"]
-        num_fsdp_wrappers = len(shard_0_metadata)
-        for fsdp_wrapper_index in range(num_fsdp_wrappers):
-            fsdp_path = shard_0_metadata[fsdp_wrapper_index]["fsdp_path"]
-            param_names = shard_0_metadata[fsdp_wrapper_index]["param_names"]
-            param_numels = shard_0_metadata[fsdp_wrapper_index]["param_numels"]
-            param_shapes = shard_0_metadata[fsdp_wrapper_index]["param_shapes"]
-
-            # Dealing with FSDP(flatten_parameter=False)
-            # For each parameter of the FSDP wrapper, get rid of the padding on each shard,
-            # concatenate the shards and reshape them to their initial shape
-            if not shard_0_metadata[fsdp_wrapper_index]["is_flat"]:
-                for i in range(len(param_names)):
-                    param_name = param_names[i]
-                    param_name = ".".join([fsdp_path, param_name]) if fsdp_path else param_name
-                    shards = []
-                    for rank in range(original_world_size):
-                        shard = shard_weights[rank][param_name]
-                        pad = shard_metadata[rank]["param_metadata"][fsdp_wrapper_index]["num_padded"][i]
-                        shards.append(_unpad(shard, pad))
-                    full_flatten_param = torch.cat(shards, dim=0)
-                    consolidated_weights[param_name] = full_flatten_param.view(param_shapes[i])
-
-            # Dealing with FSDP(flatten_parameter=True)
-            # Concatenate the merged flat_param after removing the padding
-            # and then split the flat_param by using numel, before reshaping each
-            # split to the original shape
-            else:
-                # Concatenate the flat_param parameter after removing the padding
-                flat_param_name = ".".join([fsdp_path, "flat_param_0"]) if fsdp_path else "flat_param_0"
-                shards = []
-                for rank in range(original_world_size):
-                    shard = shard_weights[rank][flat_param_name]
-                    pad = shard_metadata[rank]["param_metadata"][fsdp_wrapper_index]["num_padded"][0]
-                    shards.append(_unpad(shard, pad))
-                full_flatten_param = torch.cat(shards, dim=0)
-
-                # Split the flat_param into its constituents
-                assert sum(param_numels) == full_flatten_param.size(0)
-                for n, t, s in zip(param_names, full_flatten_param.split(param_numels), param_shapes):
-                    full_name = fsdp_path + "." + n if fsdp_path else n
-                    consolidated_weights[full_name] = t.view(s)
+        # For every FSDP instance.
+        for fsdp_obj_idx, metadata in enumerate(shard_metadata[0]["param_metadata"]):
+            fsdp_path = metadata["fsdp_path"]
+            params = metadata["params"]
+            # For every this-FSDP-owned param, flattened or not.
+            for backing_param_name, v in params.items():
+                in_state_dict_key = ".".join([fsdp_path, backing_param_name]) if fsdp_path else backing_param_name
+                # Get full param back with pad removed.
+                if in_state_dict_key not in shard_weights[0] and (not strict):
+                    continue
+                shards = []
+                for rank in range(original_world_size):
+                    shard = shard_weights[rank][in_state_dict_key]
+                    pad = shard_metadata[rank]["param_metadata"][fsdp_obj_idx]["params"][backing_param_name]["padding"]
+                    shards.append(_unpad(shard, pad))
+                    if metadata["no_broadcast_optim_state"]:
+                        break
+                full_param = torch.cat(shards, dim=0)
+                # (Potentially), split the full param and create original params.
+                names, shapes, numels, _ = v.values()
+                assert sum(numels) == full_param.size(0)
+                for n, t, s in zip(names, full_param.split(numels), shapes):
+                    out_state_dict_key = ".".join([fsdp_path, n]) if fsdp_path else n
+                    consolidated_weights[out_state_dict_key] = t.view(s)
+
+            # copy shared parameters
+            for src_path, dest_path in metadata["shared_param_info"]:
+                consolidated_weights[dest_path] = consolidated_weights[src_path]

         # Deal with the buffers, which are not parameters and are not sharded by FSDP
         # and therefore are replicated among the different shards.
         # We take the values of the first shard (this assumes that there is some form
-        # of synchronization between shards or that all shards buffers are equivalent)
+        # of synchronization between shards or that all shards buffers are equivalent).
         if with_module_buffers:
             for buffer_name in shard_metadata[0]["buffer_names"]:
+                if buffer_name not in shard_weights[0] and (not strict):
+                    continue
                 consolidated_weights[buffer_name] = shard_weights[0][buffer_name]

         return consolidated_weights
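A minimal offline-consolidation sketch built on the two methods above. It assumes each rank saved `{"weights": fsdp.local_state_dict(), "meta": fsdp.local_metadata_dict()}` as in the test further below; the file names and world size are placeholders:

```python
import torch

from fairscale.nn import FullyShardedDataParallel

world_size = 4  # number of ranks that saved shards
checkpoints = [torch.load(f"checkpoint_{rank}.pt") for rank in range(world_size)]
full_state = FullyShardedDataParallel.consolidate_shard_weights(
    shard_weights=[c["weights"] for c in checkpoints],
    shard_metadata=[c["meta"] for c in checkpoints],
)
torch.save(full_state, "consolidated.pt")
```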
@@ -1974,16 +1994,10 @@ def _pre_load_state_dict_hook(


 def _clean_path(path: str) -> str:
+    """ Remove FSDP related wrapper modules from a given state dict key str path. """
     return ".".join([split for split in path.split(".") if split not in {"_fsdp_wrapped_module", "_fpw_module"}])


-def _numel_from_size(size: torch.Size) -> int:
-    numel = 1
-    for dim in size:
-        numel *= dim
-    return numel
-
-
 def _unpad(shard: torch.Tensor, pad: int) -> torch.Tensor:
     if pad > 0:
         shard = shard[:-pad]
...
@@ -14,6 +14,7 @@ from typing import (
     Any,
     Dict,
     Generator,
+    Iterator,
     List,
     Mapping,
     NamedTuple,
@@ -72,7 +73,11 @@ class FlatParameter(nn.Parameter):
         ), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
         self._param_shapes = [p.size() for p in params]

+        # These are set by FPW class below, not by this class itself.
+        self._param_infos: List[Tuple[str, nn.Module, str]] = []
+        self._shared_param_infos: List[Tuple[str, str, nn.Module, str, nn.Module, str]] = []
+
-    def get_param_views(self, external_data: Optional[Tensor] = None) -> Generator[Tensor, None, None]:
+    def get_param_views(self, external_data: Optional[Tensor] = None) -> Iterator[Tensor]:
         """ Return a generator of views that map to the original parameters. """
         # Note, self.data could be sharded, so its numel is <= to the sum.
         assert self.data.numel() <= sum(
@@ -85,9 +90,14 @@ class FlatParameter(nn.Parameter):
         )
         return (t.view(s) for (t, s) in zip(data.split(self._param_numels), self._param_shapes))

-    def __setstate__(self, state: Tuple[Any, Any]) -> None:
+    def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
+        """Return tuple of (names, shapes, numels) metadata for this flat parameter."""
+        names = [".".join([m, n]) if m else n for (m, _, n) in self._param_infos]
+        return names, self._param_shapes, self._param_numels
+
+    def __setstate__(self, state: Tuple[Any, Any, Any, Any]) -> None:
         """ Use by pickle to set the internal states. """
-        self._param_numels, self._param_shapes = state
+        (self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos) = state
         assert self.numel() <= sum(
             self._param_numels
         ), f"Incorrect pickling {self.numel()} vs. {sum(self._param_numels)}"
@@ -96,8 +106,10 @@ class FlatParameter(nn.Parameter):
         """ Support pickling between ranks. """
         return (
             FlatParameter,  # Callable
-            ([self.data], self.requires_grad),  # Args to the callable above
-            (self._param_numels, self._param_shapes),  # Args to __setstate__
+            # Args to the callable above
+            ([self.data], self.requires_grad),
+            # Args to __setstate__
+            (self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos),
         )
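A short sketch of what `FlatParameter.metadata()` returns once the wrapper has filled in `_param_infos`; `flat_param` is a hypothetical FlatParameter and the values shown are invented:

```python
# names/shapes/numels line up index by index, one entry per original parameter.
names, shapes, numels = flat_param.metadata()
# e.g. names  == ["fc1.weight", "fc1.bias"]
#      shapes == [torch.Size([4, 4]), torch.Size([4])]
#      numels == [16, 4]
```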
@@ -115,19 +127,27 @@
     - handles state_dict/load_state_dict transparently
     - is renamed to FlattenParamsWrapper
     - refactored to use the FlatParameter class
-    - extended to support flattening multiple groups of params
+    - extended to support flattening multiple groups of params (useful
+      when different groups of params need different hyperparameters, like
+      learning rate or weight decay)

     [1] https://github.com/SsnL/PyTorch-Reparam-Module

     Args:
         module (nn.Module):
-            module to wrap.
+            The module to wrap.
         param_list (Optional[List[List[nn.Parameter]]]):
-            only flatten parameters appearing in the given groups
+            Only flatten parameters appearing in the given groups.
+            If the param_list is an empty list, then no parameters will get flattened.
+            Note, if a single param is in one of the list, it still get flattened and the
+            original param is removed and replaced with the flatten one.
             Default: None, flatten all parameters (if any)
+        flat_param_names (Optional[List[str]]):
+            originally, give each flat_param a unique name. Note a "flat_param_"
+            prefix will be added to those names.
     """

-    def __init__(self, module: nn.Module, param_list: ParamGroups = None):
+    def __init__(self, module: nn.Module, param_list: ParamGroups = None, flat_param_names: Optional[List[str]] = None):
         super().__init__()
         self._fpw_module = module
         self.is_flattened = False
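An illustrative construction with two flatten groups and custom names, assuming the usual `fairscale.nn.misc` import path; the module and grouping below are made up:

```python
import torch.nn as nn

from fairscale.nn.misc import FlattenParamsWrapper

module = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
groups = [list(module[0].parameters()), list(module[1].parameters())]
fpw = FlattenParamsWrapper(module, param_list=groups, flat_param_names=["a", "b"])
# The wrapper now owns two FlatParameters, registered as flat_param_a and flat_param_b.
assert len(fpw.flat_params) == 2
```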
@@ -176,16 +196,22 @@
             raise ValueError(f"Incorrect param groups {len(overall_param_set)} vs {self.num_param_managed}")

         self.flat_params: List[FlatParameter] = []
-        self._param_infos: List[Tuple[str, nn.Module, str]] = []
-        self._shared_param_infos: List[Tuple[nn.Module, str, nn.Module, str]] = []
+
+        # Prepare flat param names.
+        if flat_param_names is None:
+            flat_param_names = [f"{i}" for i, _ in enumerate(self._param_sets)]
+        if len(flat_param_names) != len(self._param_sets):
+            raise ValueError("Names and number of param lists must be equal")
+        if len(flat_param_names) != len(set(flat_param_names)):
+            raise ValueError("Each flat param must be given a unique name")
+        self.flat_param_names = [f"flat_param_{n}" for n in flat_param_names]

         # Init all flat_params.
         for new_p_set in self._param_sets:
-            params = self._init_flatten_params(new_p_set)
-            assert (
-                len(set(p.requires_grad for p in params)) == 1
-            ), "expects all parameters in the same parameter group of the module to have same requires_grad"
+            params, param_infos, shared_param_infos = self._init_flatten_params(new_p_set)
             flat_param = FlatParameter(params, params[0].requires_grad)
+            flat_param._param_infos = param_infos
+            flat_param._shared_param_infos = shared_param_infos
             self.flat_params.append(flat_param)

         self._flatten_params(self.flat_params)
@@ -201,17 +227,12 @@
         self._auto_unflatten_state_dict = True

     @property
-    def module(self) -> nn.Module:
+    def module(self) -> Any:
         """ Support fpw.module in case we are immitating DDP, which has .module
         property to the underlying module.
         """
         return self._fpw_module

-    @property
-    def param_path_infos(self) -> List[Tuple[str, str]]:
-        """ Returns the list of tuples that contains (module_name, param_name). """
-        return [(m, n) for (m, _, n) in self._param_infos]
-
     @property
     def flat_param(self) -> nn.Parameter:
         """ We used to support only a single flat_param. This allows us to
@@ -220,11 +241,16 @@
         assert len(self.flat_params) == 1, "Incorrect access to flat_param"
         return self.flat_params[0]

-    def _init_flatten_params(self, p_set: Set[Tuple[nn.Module, str]]) -> List[nn.Parameter]:
+    def _init_flatten_params(
+        self, p_set: Set[Tuple[nn.Module, str]]
+    ) -> Tuple[
+        List[nn.Parameter], List[Tuple[str, nn.Module, str]], List[Tuple[str, str, nn.Module, str, nn.Module, str]]
+    ]:
         """ Build metadata for need-to-be-flatten parameters and returns a list
         contains the need-to-be-flatten parameters.

-        This extends self._param_infos and self._shared_param_infos lists.
+        This also returns param_infos and shared_param_infos, which
+        will be attached to the flat parameter object.

         Args:
             p_set (set):
@@ -232,29 +258,33 @@
                 to be flattened. There could be shared params in this set.
         """
         param_infos = []
-        shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
+        shared_param_memo: Dict[nn.Parameter, Tuple[str, nn.Module, str]] = {}
         shared_param_infos = []
         params = []
         for module_name, m in self.named_modules():
             for n, p in m.named_parameters(recurse=False):
                 if p is not None and (m, n) in p_set:
                     if p in shared_param_memo:
-                        shared_m, shared_n = shared_param_memo[p]
-                        shared_param_infos.append((m, n, shared_m, shared_n))
+                        mname, shared_m, shared_n = shared_param_memo[p]
+                        shared_param_infos.append((module_name, mname, m, n, shared_m, shared_n))
                     else:
-                        shared_param_memo[p] = (m, n)
+                        shared_param_memo[p] = (module_name, m, n)
                         param_infos.append((module_name, m, n))
                         params.append(p)
         del shared_param_memo

-        assert len(set(p.dtype for p in params)) == 1, "expects all parameters in module to have same dtype"
+        assert len(set(p.dtype for p in params)) == 1, "expects all parameters to have same dtype"
+        assert len(set(p.requires_grad for p in params)) == 1, "expects all parameters to have same requires_grad"
+        assert len(params) == len(set(params)), "params list should not have dups"

-        # store the info for unflatten
-        self._param_infos += param_infos
-        self._shared_param_infos += shared_param_infos
-        assert len(params) == len(set(params)), "params list should not have dups"
-        return params
+        return params, param_infos, shared_param_infos
+
+    @property
+    def _param_infos(self) -> Iterator[Tuple[str, nn.Module, str]]:
+        return chain(*[p._param_infos for p in self.flat_params])
+
+    @property
+    def _shared_param_infos(self) -> Iterator[Tuple[str, str, nn.Module, str, nn.Module, str]]:
+        return chain(*[p._shared_param_infos for p in self.flat_params])

     def _flatten_params(self, flat_params: List[FlatParameter]) -> None:
         """ Flatten the managed parameters and replaced the original
@@ -263,15 +293,16 @@
         assert not self.is_flattened
         self.is_flattened = True

-        # flatten
-        self.flat_params = flat_params
-        for i, flat_param in enumerate(flat_params):
-            self.register_parameter(f"flat_param_{i}", flat_param)
+        # register the flatten ones and save it to self.
+        assert len(self.flat_param_names) == len(flat_params), f"{len(self.flat_param_names)} vs. {len(flat_params)}"
+        for n, flat_param in zip(self.flat_param_names, flat_params):
+            self.register_parameter(n, flat_param)
+        self.flat_params = flat_params

         # deregister the names as parameters
         for _, m, n in self._param_infos:
             delattr(m, n)
-        for m, n, _, _ in self._shared_param_infos:
+        for _, _, m, n, _, _ in self._shared_param_infos:
             delattr(m, n)

         # register the views as plain attributes
@@ -289,14 +320,14 @@
             if hasattr(m, n):
                 delattr(m, n)
             m.register_parameter(n, nn.Parameter(p))
-        for (m, n, shared_m, shared_n) in self._shared_param_infos:
+        for (_, _, m, n, shared_m, shared_n) in self._shared_param_infos:
             if hasattr(m, n):
                 delattr(m, n)
             m.register_parameter(n, getattr(shared_m, shared_n))

-        for i, _ in enumerate(self.flat_params):
+        for n in self.flat_param_names:
             # This ensures the flat params are removed from the module.
-            delattr(self, f"flat_param_{i}")
+            delattr(self, n)
         self.flat_params = []

     def _unflatten_params_as_views(self) -> None:
@@ -307,7 +338,7 @@
         ps = self.get_param_views()
         for (_, m, n), p in zip(self._param_infos, ps):
             setattr(m, n, p)  # This will set as plain attr
-        for (m, n, shared_m, shared_n) in self._shared_param_infos:
+        for (_, _, m, n, shared_m, shared_n) in self._shared_param_infos:
             setattr(m, n, getattr(shared_m, shared_n))

     @contextmanager
@@ -350,6 +381,10 @@
         except AttributeError:
             return getattr(self.module, name)  # fallback to wrapped module

+    def __getitem__(self, key: int) -> Any:
+        """Forward indexing calls in case the module is a nn.Sequential."""
+        return self.module.__getitem__(key)
+
     @typing.overload
     def state_dict(
         self, destination: Mapping[str, Tensor], prefix: str = ..., keep_vars: bool = ...
@@ -411,9 +446,7 @@
         self._unflatten_params_as_views()
         return self.module(*inputs, **kwinputs)

-    def get_param_views(
-        self, external_data_list: Optional[List[Optional[Tensor]]] = None
-    ) -> Generator[Tensor, None, None]:
+    def get_param_views(self, external_data_list: Optional[List[Optional[Tensor]]] = None) -> Iterator[Tensor]:
         """ Used to get a generator over all views from a list of external data list. """
         params = self.flat_params
         if external_data_list is None:
@@ -426,7 +459,11 @@
         for p, data in zip(params, external_data_list):
             gens.append(p.get_param_views(data))

-        return chain(*gens)  # type: ignore
+        return chain(*gens)
+
+    def metadata(self, flat_param_idx: int) -> Tuple[List[str], List[torch.Size], List[int]]:
+        """Return metadata for a flat param given its index in the flat_params list."""
+        return self.flat_params[flat_param_idx].metadata()


 def _post_state_dict_hook(
...
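A hedged sketch of feeding external storage through `get_param_views` and reading back metadata, reusing the toy `fpw` wrapper from the earlier sketch; the buffer sizes are assumed to match the unsharded flat params:

```python
import torch

# One full-size flat buffer per flat param; views are yielded in the original parameter order.
external = [torch.zeros(fp.numel()) for fp in fpw.flat_params]  # unsharded, so numel == sum of original numels
for view in fpw.get_param_views(external):
    print(tuple(view.shape))  # e.g. (4, 4) then (4,) for each group

names, shapes, numels = fpw.metadata(0)  # metadata of the first flat param group
```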
@@ -2,14 +2,23 @@
 #
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
+import functools
+import os
+import tempfile
+
+from parameterized import parameterized
 import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
+from torch.optim import Adam

 from fairscale.nn import FullyShardedDataParallel
 from fairscale.utils.testing import in_temporary_directory, skip_if_single_gpu, temp_files_ctx
+from tests.nn.data_parallel.test_fsdp import DistributedTest, MixtureOfExperts, rename_test, spawn_and_init
+
+USE_TEMPFILE = True  # False for debugging


 class ConvolutionalModel(nn.Module):
@@ -145,3 +154,102 @@ def test_consolidation(embedding_size: int, flatten_parameters: bool):
     with in_temporary_directory():
         with temp_files_ctx(num=1) as temp_files:
             mp.spawn(_worker, (temp_files[0], world_size, embedding_size, flatten_parameters), nprocs=world_size)
+
+
+@skip_if_single_gpu
+class TestConsolidatedWeights(DistributedTest):
+    @parameterized.expand(
+        [[True], [False]], name_func=rename_test,
+    )
+    def test_consolidate_weights(self, transformer):
+        config = {"mixed_precision": True, "flatten_parameters": True, "compute_dtype": torch.float32}
+        world_size = min(torch.cuda.device_count(), 4)
+
+        if USE_TEMPFILE:
+            with tempfile.TemporaryDirectory() as d:
+                paths = [os.path.join(d, f"checkpoint_{rank}.pt") for rank in range(world_size)]
+                test_fn = functools.partial(
+                    self._test_consolidate_weights, config, transformer=transformer, paths=paths
+                )
+                spawn_and_init(test_fn, world_sizes=[world_size])
+        else:
+            paths = [f"checkpoint_{rank}.pt" for rank in range(world_size)]
+            test_fn = functools.partial(self._test_consolidate_weights, config, transformer=transformer, paths=paths)
+            spawn_and_init(test_fn, world_sizes=[world_size])
+
+    @classmethod
+    def _test_consolidate_weights(self, config, rank, group, paths=None, transformer=False):
+        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
+        # Establish reference behavior.
+        if transformer:
+            fsdp = self.get_wrapped_model(group, config=config).cuda()
+        else:
+            fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()
+
+        optim = Adam(fsdp.parameters(), lr=0.01,)
+        optim.zero_grad()
+        with torch.cuda.amp.autocast(enabled=True):
+            x = fsdp.module.get_input(torch.device("cuda"))
+            output = fsdp(*x)
+            loss = fsdp.module.get_loss(x, output).to("cuda")
+            fsdp.module.run_backward(loss)
+            optim.step()
+
+        # each worker saves a checkpoint with local_state_dict
+        cp_data = {
+            "weights": {k: v.cpu() for k, v in fsdp.local_state_dict().items()},
+            "meta": fsdp.local_metadata_dict(),
+        }
+        torch.save(cp_data, paths[fsdp.rank])
+
+        full_model_state_dict = fsdp.state_dict()
+        torch.distributed.barrier()
+        if fsdp.rank > 0:
+            return
+        all_checkpoints = [torch.load(p) for p in paths]
+        consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
+            shard_weights=[c["weights"] for c in all_checkpoints], shard_metadata=[c["meta"] for c in all_checkpoints],
+        )
+        full_model_extra = set(full_model_state_dict).difference(set(consolidated_checkpoint))
+        consolidated_extra = set(consolidated_checkpoint).difference(set(full_model_state_dict))
+        msg = f"full model extra keys: {full_model_extra}, consolidated extra {consolidated_extra}"
+        for k in full_model_state_dict.keys():
+            assert consolidated_checkpoint[k].shape == full_model_state_dict[k].shape
+        assert set(full_model_state_dict.keys()) == set(consolidated_checkpoint.keys()), msg
+
+
+def test_consolidate_missing_params():
+    """This tests that fairseq experts, which are saved independently from the rest of the model, can be consolidated."""
+    desired_path = "decoder.layers.1.moe_layer.experts.0"
+    shard_metadata = {
+        "param_metadata": [
+            {
+                "fsdp_path": "",
+                "params": {
+                    "flat_param_0": {"names": ["missing"], "shapes": [(12, 4)], "numels": [12 * 4], "padding": 0}
+                },
+                "no_broadcast_optim_state": False,
+                "shared_param_info": [],
+            },
+            {
+                "fsdp_path": desired_path,
+                "params": {
+                    "flat_param_0": {
+                        "names": ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
+                        "shapes": [(4, 4), (4,), (4, 4), (4,)],
+                        "numels": [16, 4, 16, 4],
+                        "padding": 0,
+                    }
+                },
+                "no_broadcast_optim_state": True,
+                "shared_param_info": [],
+            },
+        ],
+        "buffer_names": ["missing.buffer"],
+    }
+    shard_weights = {"decoder.layers.1.moe_layer.experts.0.flat_param_0": torch.randn(40, dtype=torch.float16)}
+    consolidated_weights = FullyShardedDataParallel.consolidate_shard_weights(
+        [shard_weights], [shard_metadata], strict=False
+    )
+    assert len(consolidated_weights) == 4
+    for k in consolidated_weights:
+        assert k.startswith(desired_path), f"{k} doesnt start with {desired_path}"
@@ -60,7 +60,7 @@ class TestLocalStateDict(DistributedTest):
         # Assert that parameters were updated since before training
         unchanged = []
-        unwrapped_model = model.module.module if config["flatten_parameters"] else model.module
+        unwrapped_model = model.module.module
         buffers = {name for name, _ in unwrapped_model.named_buffers()}
         for k in state_1:
             if (state_before_training[k] == state_after_training[k]).all() and (k not in buffers):
...
@@ -24,6 +24,8 @@ except ImportError:

 class TestAutoWrap(unittest.TestCase):
     def setUp(self) -> None:
+        # For all the tests here, we use a fake group and flatten being False since those should
+        # not affect how wrapping work.
         self.process_group = DummyProcessGroup(rank=0, size=1)

     def test_wrap(self):
...