Unverified Commit d60fc284 authored by Min Xu, committed by GitHub

[feat] supporting multiple flatten parameter groups (step 1 and step 1.5) (#708)



* refactoring FlattenParamWrapper

- use a FlatParameter class to encapsulate the logic of
  flattening and expanding into views.
- this will make it easier to have multiple groups of flatten
  parameters

* fixed testing context issues for both temp files and temp dirs

* fixing test_fsdp_metadata

* fix pickling of FlatParameter

* fixed test_fsdp_optimizer_utils.py

* minor

* fix assert

* lint

* remove nesting from the test

* step 1.5: remove the code related to unnecessary nesting support in FPW

* Update fairscale/nn/misc/flatten_params_wrapper.py
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

* address comment
Co-authored-by: Min Xu <min.xu.public@gmail.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 1e4a503c
......@@ -4,10 +4,12 @@
# LICENSE file in the root directory of this source tree.
"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
import copy
from typing import Any, Dict, Generator, List, Tuple
from typing import Any, Dict, Generator, List, Tuple, cast
import torch
from fairscale.nn.misc import FlattenParamsWrapper
# These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}
......@@ -96,8 +98,9 @@ def _unflatten_optim_state(
# we check that these are identical across workers and then take the first
non_tensor_state = [_extract_non_tensor_state(combined_state, id) for id in combined_state]
# local corresponds to flattened, global corresponds to unflattened
num_global_params = [len(m._param_numels) for m in instance_list] # type: ignore
# Local corresponds to flattened, global corresponds to unflattened.
# Casting needed only for mypy.
num_global_params = [cast(int, m.num_params_managed) for m in instance_list]
global_to_local_id = {}
for local_id, num_unflat in enumerate(num_global_params):
for _ in range(num_unflat):
......@@ -126,7 +129,8 @@ def _unflatten_optim_state(
assert isinstance(v, list), f"got {k}: {v} for {local_id}"
v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
flat_buffer = torch.cat(v_unpad)
param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) # type: ignore
# Casting needed only for mypy.
param_views: Generator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views([flat_buffer])
for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
unflat_state[global_id][k] = param_view
......@@ -154,9 +158,10 @@ def build_unflat_state_dict(
singleton_state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if is_singleton_tensor(x)}
# local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(state, instance_list, world_pad_info, singleton_state)
# Since there are no tensors in param_groups, deepcopy is fine
# Since there are no tensors in param_groups, deepcopy is fine.
param_groups = copy.deepcopy(param_groups)
num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
# Casting needed only for mypy.
num_params = sum([cast(int, m.num_params_managed) for m in instance_list])
param_groups[0]["params"] = list(range(num_params))
unflat_optim_state_dict = {
"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted
......
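For concreteness, a hedged sketch (not code from this diff) of the mapping the loop above is meant to build: with two wrapped instances managing 3 and 2 original parameters, consecutive global (unflattened) ids point back to the local (flattened) parameter they came from.

num_global_params = [3, 2]   # e.g. [m.num_params_managed for m in instance_list]
# presumed result of the (elided) loop body: global (unflattened) id -> local (flattened) id
expected_global_to_local_id = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1}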
......@@ -833,7 +833,7 @@ class FullyShardedDataParallel(nn.Module):
# latter may contain padding.
assert len(self.params) == 1
assert isinstance(self.module, FlattenParamsWrapper)
stack.enter_context(self.module.unflatten_params(recurse=False, flat_param=self.params[0]))
stack.enter_context(self.module.unflatten_params(flat_param=self.params[0]))
try:
yield
finally:
......@@ -1534,7 +1534,7 @@ class FullyShardedDataParallel(nn.Module):
# There are as many sharded parameters as there are parameters in the
# consolidated model, so we only need to export how to reshape the
# parameters to their original shape and take care of the padding
if not hasattr(m, "_param_numels"):
if not m.flatten_parameters:
params_metadata.append(
{
"fsdp_path": _clean_path(path),
......@@ -1563,8 +1563,11 @@ class FullyShardedDataParallel(nn.Module):
"is_flat": True,
"num_padded": m.numel_padded_per_param,
"param_names": param_names,
"param_shapes": m._param_shapes,
"param_numels": m._param_numels,
# TODO (Min): we don't want to access the private _param_shapes and
# _param_numels here. We want to dump metadata from FPW when there are
# multiple groups of params.
"param_shapes": m._fsdp_wrapped_module.flat_param._param_shapes, # type: ignore
"param_numels": m._fsdp_wrapped_module.flat_param._param_numels, # type: ignore
"no_broadcast_optim_state": m.no_broadcast_optim_state,
}
)
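For a flattened instance wrapping a single Linear(4, 4), the appended entry would look roughly like the sketch below (hedged illustration only: the fsdp_path, num_padded and param_names values are placeholders and depend on sharding and module paths):

{
    "fsdp_path": "",                 # placeholder: cleaned module path
    "is_flat": True,
    "num_padded": [0],               # placeholder: padding depends on world size
    "param_names": ["weight", "bias"],
    "param_shapes": [torch.Size([4, 4]), torch.Size([4])],
    "param_numels": [16, 4],
    "no_broadcast_optim_state": False,
}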
......@@ -1589,7 +1592,7 @@ class FullyShardedDataParallel(nn.Module):
If this behavior is not correct for your module (for instance if buffers
need to be reduced instead), you can disable it with `with_module_buffers=False`.
This method is very useful to re-assemble checkpoints of shards without
This method is used to re-assemble checkpoints of shards without
having to instantiate FSDP wrappers with the world size originally used
to save the shards.
"""
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Tongzhou Wang
# Licensed under the MIT License.
from contextlib import ExitStack, contextmanager
from contextlib import contextmanager
from itertools import chain
import typing
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, NamedTuple, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
......@@ -15,6 +21,72 @@ if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
class FlatParameter(nn.Parameter):
""" A parameter that is initialized from a list of parameters and can be
turned into a list of views as needed.
"""
def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) -> "FlatParameter":
""" Make an object using the parent's __new__ function. """
# An empty or non-list input doesn't make sense.
if not isinstance(params, (list, tuple)) or len(params) == 0:
raise ValueError("A non-empty list or tuple argument is needed")
# Normally, all items are Parameters. However, during pickling (see __reduce_ex__ below),
# the input is a list holding a single flat Tensor, and __setstate__ later restores the
# correct _param_numels and _param_shapes.
if not all(isinstance(p, (nn.Parameter, Tensor)) for p in params):
raise ValueError("List items need to be Parameter types")
# Flattening involves (1) making a tensor flat (i.e. single dimensional) and (2) making a module
# hierarchy flat (using a single tensor to replace a tree of tensors). Therefore,
# adding back nesting and hierarchy is counter-productive. If nesting is encountered
# in the future, the reasonable thing to do is likely for the top level FlatParameter to
# absorb the nested one and keep the result flat, free from hierarchy.
if any(isinstance(p, FlatParameter) for p in params):
raise ValueError("Nesting FlatParameter is not supported")
data = torch.cat([p.detach().reshape(-1) if isinstance(p, nn.Parameter) else p.reshape(-1) for p in params], 0)
return super(FlatParameter, cls).__new__(cls, data, requires_grad=requires_grad)
def __init__(self, params: Sequence[nn.Parameter], requires_grad: bool = True):
""" Initialize the _param_numels and _param_shapes lists. """
self._param_numels = [p.numel() for p in params]
assert self.numel() <= sum(
self._param_numels
), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
self._param_shapes = [p.size() for p in params]
def get_param_views(self, external_data: Optional[Tensor] = None) -> Generator[Tensor, None, None]:
""" Return a generator of views that map to the original parameters. """
# Note, self.data could be sharded, so its numel is <= the sum.
assert self.data.numel() <= sum(
self._param_numels
), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
data = external_data if external_data is not None else self
if data.numel() != sum(self._param_numels):
raise ValueError(
f"Incorrect numel of supplied data: got {data.numel()} but expected {sum(self._param_numels)}"
)
return (t.view(s) for (t, s) in zip(data.split(self._param_numels), self._param_shapes))
def __setstate__(self, state: Tuple[Any, Any]) -> None:
""" Use by pickle to set the internal states. """
self._param_numels, self._param_shapes = state
assert self.numel() <= sum(
self._param_numels
), f"Incorrect pickling {self.numel()} vs. {sum(self._param_numels)}"
def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any]:
""" Support pickling between ranks. """
return (
FlatParameter, # Callable
([self.data], self.requires_grad), # Args to the callable above
(self._param_numels, self._param_shapes), # Args to __setstate__
)
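A minimal usage sketch of the class above (not part of the diff; it assumes FlatParameter is importable from its defining module): flatten two small parameters, then recover correctly shaped views.

import torch
import torch.nn as nn
from fairscale.nn.misc.flatten_params_wrapper import FlatParameter  # assumed import path

p1, p2 = nn.Parameter(torch.randn(2, 3)), nn.Parameter(torch.randn(4))
flat = FlatParameter([p1, p2])
assert flat.numel() == 10                # 6 + 4 elements, concatenated in order
v1, v2 = flat.get_param_views()          # views over flat.data, in the original order
assert v1.shape == (2, 3) and v2.shape == (4,)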
class FlattenParamsWrapper(nn.Module):
"""
A wrapper for transparently flattening a Module's parameters.
......@@ -24,6 +96,7 @@ class FlattenParamsWrapper(nn.Module):
- supports shared parameters
- handles state_dict/load_state_dict transparently
- is renamed to FlattenParamsWrapper
- refactored to use the FlatParameter class
[1] https://github.com/SsnL/PyTorch-Reparam-Module
......@@ -44,6 +117,10 @@ class FlattenParamsWrapper(nn.Module):
param_list = list(module.parameters())
param_set = set(param_list)
# Since the parameters will be deleted, let's record the number of original parameters
# managed by this class.
self.num_params_managed = len(param_set)
# convert from list of Parameters to set of (Module, name) tuples, which
# will survive in case the Parameter instances are reset
self._param_set = set()
......@@ -52,6 +129,7 @@ class FlattenParamsWrapper(nn.Module):
if p in param_set:
self._param_set.add((m, n))
# TODO (Min): double check we handle the special case of a module without any params.
self._flatten_params()
# Register hook to be called after state_dict() to remove the
......@@ -68,14 +146,15 @@ class FlattenParamsWrapper(nn.Module):
def module(self) -> nn.Module:
return self._fpw_module
def _init_flatten_params(self) -> List[Tensor]:
def _init_flatten_params(self) -> List[nn.Parameter]:
""" Build metadata for need-to-be-flatten parameters and returns a list
contains the need-to-be-flatten parameters.
"""
param_infos = []
param_full_infos = []
shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
shared_param_infos = []
params = []
param_numels = []
param_shapes = []
for module_name, m in self.named_modules():
for n, p in m.named_parameters(recurse=False):
if p is not None and (m, n) in self._param_set:
......@@ -86,9 +165,7 @@ class FlattenParamsWrapper(nn.Module):
shared_param_memo[p] = (m, n)
param_infos.append((m, n))
param_full_infos.append((module_name, n))
params.append(p.detach())
param_numels.append(p.numel())
param_shapes.append(p.size())
params.append(p)
del shared_param_memo
assert len(set(p.dtype for p in params)) <= 1, "expects all parameters in module to have same dtype"
......@@ -97,8 +174,6 @@ class FlattenParamsWrapper(nn.Module):
self._param_infos = tuple(param_infos)
self._param_full_infos = tuple(param_full_infos)
self._shared_param_infos = tuple(shared_param_infos)
self._param_numels = tuple(param_numels)
self._param_shapes = tuple(param_shapes)
return params
......@@ -109,9 +184,8 @@ class FlattenParamsWrapper(nn.Module):
if not hasattr(self, "_param_infos"):
assert flat_param is None
params = self._init_flatten_params()
flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
flat_param = FlatParameter(params)
self.param_numel = flat_param.numel()
del params
# flatten
assert flat_param is not None
......@@ -126,15 +200,14 @@ class FlattenParamsWrapper(nn.Module):
# register the views as plain attributes
self._unflatten_params_as_views()
def get_param_views(self, flat_param: Tensor) -> Generator:
return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))
def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:
assert self.is_flattened or flat_param is not None
def _unflatten_params(self, external_data: Optional[Tensor] = None) -> None:
""" Undo flattening and create separate parameters from the already flattened
self.flat_param or user-supplied external data.
"""
assert self.is_flattened or external_data is not None
self.is_flattened = False
flat_param = flat_param if flat_param is not None else self.flat_param
ps = self.get_param_views(flat_param)
ps = self.get_param_views([external_data])
for (m, n), p in zip(self._param_infos, ps):
if hasattr(m, n):
delattr(m, n)
......@@ -147,22 +220,23 @@ class FlattenParamsWrapper(nn.Module):
del self.flat_param
def _unflatten_params_as_views(self) -> None:
""" Unlike ``_unflatten_params``, this function unflatten into views and keep
self.flat_param unchanged.
"""
assert self.is_flattened
ps = self.get_param_views(self.flat_param)
ps = self.get_param_views()
for (m, n), p in zip(self._param_infos, ps):
setattr(m, n, p) # This will set as plain attr
for (m, n, shared_m, shared_n) in self._shared_param_infos:
setattr(m, n, getattr(shared_m, shared_n))
@contextmanager
def unflatten_params(self, recurse: bool = True, flat_param: Optional[Tensor] = None) -> Generator:
def unflatten_params(self, flat_param: Optional[Tensor] = None) -> Generator:
"""
Unflatten params. If the current instance is already unflattened, then
it will remain unflattened after the context manager exits.
Args:
recurse (bool, Optional): recursively unflatten all nested instances
(default: True)
flat_param (Tensor, Optional): flat param to use for unflattening.
If provided, the current instance must be in a flattened state
at the start of the context manager. The provided Tensor must be
......@@ -170,31 +244,20 @@ class FlattenParamsWrapper(nn.Module):
manager. After the context manager exits, we will revert to
using ``self.flat_param`` (default: None).
"""
if recurse:
with ExitStack() as stack:
# unflatten any nested FlattenParamsWrapper instances
for name, module in self.named_modules():
if isinstance(module, FlattenParamsWrapper):
is_self = name == ""
stack.enter_context(
module.unflatten_params(recurse=False, flat_param=flat_param if is_self else None)
)
# yield to the caller, with unflattened params in all nested instances
yield
# exiting from the ExitStack will re-flatten params
return
else:
assert (
flat_param is None or self.is_flattened
), "Unflattening with custom flat_param requires current instance to be flattened"
), "Unflattening with external flat_param requires current instance to be flattened"
orig_flattened = self.is_flattened
if orig_flattened:
orig_flat_param = self.flat_param
self._unflatten_params(flat_param)
# Put the yield in a try...finally so that the original state is properly restored
# even if the caller's block raises an exception.
try:
yield
finally:
if orig_flattened:
self._flatten_params(orig_flat_param)
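A hedged usage sketch of the simplified (non-recursive) context manager; the zero-filled external tensor is only for illustration.

import torch
import torch.nn as nn
from fairscale.nn.misc import FlattenParamsWrapper

fpw = FlattenParamsWrapper(nn.Linear(4, 4))
external = torch.zeros(fpw.flat_param.numel())
with fpw.unflatten_params(flat_param=external):
    # inside the context, the wrapped module sees real parameters built from `external`
    assert fpw.module.weight.abs().sum().item() == 0.0
assert fpw.is_flattened  # re-flattened (backed by the original flat_param) on exit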
......@@ -217,29 +280,30 @@ class FlattenParamsWrapper(nn.Module):
# Since we have overloads above, we can use Any here.
def state_dict(self, *args: Any, **kwargs: Any) -> Any:
"""Return the wrapped module's state_dict (unflattened)."""
"""Return the wrapped module's state_dict."""
if self.is_flattened and self._auto_unflatten_state_dict:
with self.unflatten_params(recurse=False):
# Returns the original version.
with self.unflatten_params():
return super().state_dict(*args, **kwargs)
else:
# Returns flattened version.
return super().state_dict(*args, **kwargs)
def flat_state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Return the flattened state_dict."""
assert self.is_flattened
with ExitStack() as stack:
# tell any nested FlattenParamsWrapper instances not to auto unflatten
for module in self.modules(): # includes self
if isinstance(module, FlattenParamsWrapper):
stack.enter_context(module._no_auto_unflatten_state_dict())
state_dict = self.state_dict(*args, **kwargs)
return state_dict
with self._no_auto_unflatten_state_dict():
return self.state_dict(*args, **kwargs)
@contextmanager
def _no_auto_unflatten_state_dict(self) -> Generator:
backup = self._auto_unflatten_state_dict
self._auto_unflatten_state_dict = False
# Put the yield in a try...finally so that the original state is properly restored
# even if the caller's block raises an exception.
try:
yield
finally:
self._auto_unflatten_state_dict = backup
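A short hedged sketch of the difference between the two state_dict flavors; only the "flat_param" key is asserted, since the post-state_dict hook controls the remaining key names.

import torch.nn as nn
from fairscale.nn.misc import FlattenParamsWrapper

fpw = FlattenParamsWrapper(nn.Linear(4, 4))
assert "flat_param" in fpw.flat_state_dict()   # flat form: a single flattened tensor
assert "flat_param" not in fpw.state_dict()    # auto-unflattened, original-style keys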
def load_state_dict(
......@@ -251,15 +315,31 @@ class FlattenParamsWrapper(nn.Module):
"""
# unflatten the module automatically if the state_dict is non-flat
if self.is_flattened and "flat_param" not in state_dict:
with self.unflatten_params(recurse=True):
# This object is flattened but the state_dict is not, so we unflatten and load.
with self.unflatten_params():
return super().load_state_dict(state_dict, strict)
else:
# Otherwise, load it as-is.
return super().load_state_dict(state_dict, strict)
def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
self._unflatten_params_as_views()
return self.module(*inputs, **kwinputs)
def get_param_views(
self, external_data_list: Optional[List[Optional[Tensor]]] = None
) -> Generator[Tensor, None, None]:
""" Used to get a generator over all views from a list of external data list. """
params = [self.flat_param] # For now, there is only a single flat param.
if external_data_list is None:
external_data_list = [None] * len(params)
gens = []
for p, data in zip(params, external_data_list):
gens.append(p.get_param_views(data))
return chain(*gens) # type: ignore
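A hedged sketch of the new list-based signature (one external buffer per flat param; currently a single-element list):

import torch
import torch.nn as nn
from fairscale.nn.misc import FlattenParamsWrapper

fpw = FlattenParamsWrapper(nn.Linear(4, 4))
buf = torch.arange(fpw.flat_param.numel(), dtype=torch.float32)
views = list(fpw.get_param_views([buf]))   # views cut out of `buf`, not of flat_param
assert [tuple(v.shape) for v in views] == [(4, 4), (4,)]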
def _post_state_dict_hook(
module: nn.Module, state_dict: "OrderedDict[str, Tensor]", prefix: str, *args: Any
......
......@@ -674,7 +674,9 @@ def in_temporary_directory() -> Generator:
old_cwd = os.getcwd()
with tempfile.TemporaryDirectory() as temp_dir:
os.chdir(temp_dir)
try:
yield temp_dir
finally:
os.chdir(old_cwd)
......@@ -683,8 +685,9 @@ def temp_files_ctx(num: int) -> Generator:
""" A context to get tempfiles and ensure they are cleaned up. """
files = [tempfile.mkstemp()[1] for _ in range(num)]
try:
yield tuple(files)
finally:
# temp files could have been removed, so we use rmf.
for name in files:
rmf(name)
......
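A hedged usage sketch of the fixed helpers (the import location is assumed from the surrounding test utilities):

from fairscale.utils.testing import in_temporary_directory, temp_files_ctx  # assumed location

with in_temporary_directory() as temp_dir:
    with temp_files_ctx(2) as (file_a, file_b):
        with open(file_a, "w") as f:
            f.write("scratch data")
# the cwd is restored and both temp files are removed, even if the body raises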
......@@ -121,6 +121,9 @@ class Tensor:
_has_been_cloned: Optional[bool] = ...
#END
@staticmethod
def _make_subclass(cls: Any, data: Tensor, requires_grad: builtins.bool) -> Any: ...
def __init__(self, *args, **kwargs) -> None: ...
@property
......
......@@ -15,6 +15,8 @@ class Parameter(Tensor):
_fp32_shard: Tensor
_fp16_shard: Optional[Tensor]
def __new__(cls, data: Tensor, requires_grad: builtins.bool = True): ...
def __init__(self, data: Tensor, requires_grad: builtins.bool = True): ...
...
......@@ -5,6 +5,7 @@
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel
......@@ -139,7 +140,6 @@ def _worker(gpu_id: int, sync_file: str, world_size: int, embedding_size: int, f
@pytest.mark.parametrize("embedding_size", [128, 129])
@pytest.mark.parametrize("flatten_parameters", [True, False])
def test_consolidation(embedding_size: int, flatten_parameters: bool):
import torch.multiprocessing as mp
world_size = 2
with in_temporary_directory():
......
......@@ -19,8 +19,8 @@ from fairscale.utils.testing import objects_are_equal
class TestFlattenParams(unittest.TestCase):
def _get_module_init_fns(self):
return [
self._get_basic_linear_module,
self._get_shared_params_transformer,
self._get_nested_flat_module,
]
def _get_transformer(self, seed=0):
......@@ -47,13 +47,11 @@ class TestFlattenParams(unittest.TestCase):
dec_layer.linear2.weight = enc_layer.linear2.weight
return module
def _get_nested_flat_module(self, seed=0):
def _get_basic_linear_module(self, seed=0):
module = torch.nn.Sequential(
FlattenParamsWrapper(
torch.nn.Sequential(torch.nn.Linear(4, 8), FlattenParamsWrapper(torch.nn.Linear(8, 8)))
),
FlattenParamsWrapper(torch.nn.Sequential(FlattenParamsWrapper(torch.nn.Linear(8, 16)))),
FlattenParamsWrapper(torch.nn.Linear(16, 4)),
torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 8)),
torch.nn.Sequential(torch.nn.Linear(8, 16)),
torch.nn.Linear(16, 4),
)
def get_input(device, dtype):
......