Unverified Commit ab71efb3 authored by Min Xu, committed by GitHub

[feat] FSDP: supporting multiple flatten parameter groups (#711)



* [feat] FSDP: supporting multiple flatten parameter groups

- step 2: extend FPW to support multiple flat param groups
- FSDP still only uses one group
- unit tests exercise the new code paths (a usage sketch follows below)
- updated the changelog

* first cut, mypy passed

* test_flatten_params_wrapper.py::TestFlattenParams tests pass

* added two more test cases and fixed a case in the code

* fixed one bug with param_path_infos

* fixed two more tests with hardcoded flat_param names

* Update CHANGELOG.md
Co-authored-by: Min Xu <min.xu.public@gmail.com>
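
A minimal usage sketch of the new multi-group API, distilled from the tests added in this PR. The transformer construction and the particular grouping are illustrative only, and the ``fairscale.nn.misc`` import path is assumed:

import torch
from fairscale.nn.misc import FlattenParamsWrapper

module = torch.nn.Transformer(d_model=32, nhead=4, num_encoder_layers=2, num_decoder_layers=2)

# Two disjoint groups of parameters. Sharing a parameter across groups is
# rejected with a ValueError, since parts of two flat buffers cannot alias.
group1 = list(module.encoder.layers[1].parameters()) + list(module.decoder.layers[0].parameters())
group2 = list(module.encoder.layers[0].parameters()) + list(module.decoder.layers[1].parameters())

wrapped = FlattenParamsWrapper(module, param_list=[group1, group2])

# Each group becomes one FlatParameter, registered as flat_param_0, flat_param_1, ...
assert len(wrapped.flat_params) == 2
assert wrapped.flat_params[0].numel() == sum(p.numel() for p in group1)

# Passing a single flat list (or None) still yields exactly one group, so the
# legacy `wrapped.flat_param` accessor keeps working in that case.
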
parent cec011bb
......@@ -6,11 +6,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## NEXT - TBD
### Fixed
- doc: thoroughly improved the doc
- checkpointing: use dummy tensor to ensure backward pass is called [#701]
- checkpointing: ensure internal fwd counter is not incremented in eval mode [#709]
- FSDP: fixed bug where buffers returned in `state_dict()` could still be half precision when `mixed_precision` is set to `True`.
- FSDP: fixed bug where buffers returned in `state_dict()` could still be half precision when `mixed_precision` is set to `True`. [#705]
### Added
- FSDP: supporting multiple flatten parameter groups [#708] [#711]
## [0.3.7] - 2021-05-17
### Fixed
......
......@@ -833,7 +833,7 @@ class FullyShardedDataParallel(nn.Module):
# latter may contain padding.
assert len(self.params) == 1
assert isinstance(self.module, FlattenParamsWrapper)
stack.enter_context(self.module.unflatten_params(flat_param=self.params[0]))
stack.enter_context(self.module.unflatten_params(flat_params=[self.params[0]]))
try:
yield
finally:
......@@ -1185,6 +1185,7 @@ class FullyShardedDataParallel(nn.Module):
# Register a hook on the first call, empirically, autograd
# fires it at the end for this param, which makes sense.
p_tmp = p.expand_as(p) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
p._shard_bwd_hook = (grad_acc, handle)
......@@ -1417,6 +1418,7 @@ class FullyShardedDataParallel(nn.Module):
output_tensors.append((p.data, True))
elif not p._is_sharded:
if self.mixed_precision and not force_full_precision:
assert p._fp16_shard is not None
p.data = p._fp16_shard
output_tensors.append((p.data, True))
else:
......@@ -1483,6 +1485,7 @@ class FullyShardedDataParallel(nn.Module):
for p in self.params:
if not p._is_sharded:
if self.mixed_precision:
assert p._fp16_shard is not None
assert p._fp16_shard.storage().size() != 0
p.data = p._fp16_shard
else:
......@@ -1554,7 +1557,7 @@ class FullyShardedDataParallel(nn.Module):
# used in the FlattenParamsWrapper
else:
param_names = []
for module_path, param_name in m._param_full_infos:
for module_path, param_name in m.param_path_infos:
full_param_path = module_path + "." + param_name if module_path else param_name
param_names.append(_clean_path(full_param_path))
params_metadata.append(
......@@ -1633,7 +1636,7 @@ class FullyShardedDataParallel(nn.Module):
# split to the original shape
else:
# Concatenate the flat_param parameter after removing the padding
flat_param_name = ".".join([fsdp_path, "flat_param"]) if fsdp_path else "flat_param"
flat_param_name = ".".join([fsdp_path, "flat_param_0"]) if fsdp_path else "flat_param_0"
shards = []
for rank in range(original_world_size):
shard = shard_weights[rank][flat_param_name]
......
......@@ -9,7 +9,21 @@
from contextlib import contextmanager
from itertools import chain
import typing
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
import torch
from torch import Tensor
......@@ -87,6 +101,10 @@ class FlatParameter(nn.Parameter):
)
# Static types.
ParamGroups = Optional[Union[List[List[nn.Parameter]], List[nn.Parameter]]]
class FlattenParamsWrapper(nn.Module):
"""
A wrapper for transparently flattening a Module's parameters.
......@@ -97,40 +115,77 @@ class FlattenParamsWrapper(nn.Module):
- handles state_dict/load_state_dict transparently
- is renamed to FlattenParamsWrapper
- refactored to use the FlatParameter class
- extended to support flattening multiple groups of params
[1] https://github.com/SsnL/PyTorch-Reparam-Module
Args:
module (nn.Module): module to wrap
param_list (Optional[List[nn.Parameter]]): only flatten parameters
appearing in the given list (default: flatten all parameters)
module (nn.Module):
module to wrap.
param_list (ParamGroups):
only flatten parameters appearing in the given groups; a single flat
list of parameters is treated as one group.
Default: None, which flattens all parameters (if any).
"""
def __init__(self, module: nn.Module, param_list: Optional[List[nn.Parameter]] = None):
def __init__(self, module: nn.Module, param_list: ParamGroups = None):
super().__init__()
self._fpw_module = module
self.is_flattened = False
if param_list is not None:
assert len(param_list) > 0, "param_list can't be empty"
else:
# Handle param_list being None.
if param_list is None:
param_list = list(module.parameters())
param_set = set(param_list)
# Since the parameters will be deleted, let's record the number original parameters
# managed by this class.
self.num_params_managed = len(param_set)
# convert from list of Parameters to set of (Module, name) tuples, which
# will survive in case the Parameter instances are reset
self._param_set = set()
# Be backward compatible and turn a single param list into a list of
# a single list.
if len(param_list) > 0 and isinstance(param_list[0], nn.Parameter):
param_list = [cast(List[nn.Parameter], param_list)]
# Since the parameters will be deleted, let's record the number of original
# parameters managed by this class. This and the get_param_views function
# below are used by fsdp_optim_utils.py to save/restore optimizer state,
# which mirrors the flattened parameters here.
self.num_params_managed = 0
self._param_sets = []
overall_param_set: Set[nn.Parameter] = set()
for p_list in param_list:
# Remove any duplicates from the list.
p_set: Set[nn.Parameter] = set(cast(List[nn.Parameter], p_list))
self.num_params_managed += len(p_set)
overall_param_set = overall_param_set.union(p_set)
# Convert from list of Parameters to set of (Module, name) tuples,
# which will survive in case the parameter instances are reset.
# Also, a shared param will correctly appear under multiple modules,
# as it should.
new_p_set_with_names = set()
for m in self.modules():
for n, p in m.named_parameters(recurse=False):
if p in param_set:
self._param_set.add((m, n))
if p in p_set:
new_p_set_with_names.add((m, n))
if new_p_set_with_names:
self._param_sets.append(new_p_set_with_names)
if len(overall_param_set) != self.num_params_managed:
# Each p_list above could have shared params. However, you can't
# have shared params across different p_lists, because that would mean
# part of one flattened parameter had to be shared with another, which
# is impossible to support.
raise ValueError(f"Incorrect param groups {len(overall_param_set)} vs {self.num_params_managed}")
self.flat_params: List[FlatParameter] = []
self._param_infos: List[Tuple[str, nn.Module, str]] = []
self._shared_param_infos: List[Tuple[nn.Module, str, nn.Module, str]] = []
# Init all flat_params.
for new_p_set in self._param_sets:
params = self._init_flatten_params(new_p_set)
flat_param = FlatParameter(params)
self.flat_params.append(flat_param)
# TODO (Min): double check we handle the special case of module without any params.
self._flatten_params()
self._flatten_params(self.flat_params)
# Register hook to be called after state_dict() to remove the
# "_fpw_module." prefix and before load_state_dict() to add it back.
......@@ -144,55 +199,74 @@ class FlattenParamsWrapper(nn.Module):
@property
def module(self) -> nn.Module:
""" Support fpw.module in case we are immitating DDP, which has .module
property to the underlying module.
"""
return self._fpw_module
def _init_flatten_params(self) -> List[nn.Parameter]:
@property
def param_path_infos(self) -> List[Tuple[str, str]]:
""" Returns the list of tuples that contains (module_name, param_name). """
return [(m, n) for (m, _, n) in self._param_infos]
@property
def flat_param(self) -> nn.Parameter:
""" We used to support only a single flat_param. This allows us to
be backward compatible.
"""
assert len(self.flat_params) == 1, "Incorrect access to flat_param"
return self.flat_params[0]
def _init_flatten_params(self, p_set: Set[Tuple[nn.Module, str]]) -> List[nn.Parameter]:
""" Build metadata for need-to-be-flatten parameters and returns a list
contains the need-to-be-flatten parameters.
This extends self._param_infos and self._shared_param_infos lists.
Args:
p_set (set):
A set of (module, param_name) tuples for the params that need
to be flattened. There may be shared params in this set.
"""
param_infos = []
param_full_infos = []
shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str]] = {}
shared_param_infos = []
params = []
for module_name, m in self.named_modules():
for n, p in m.named_parameters(recurse=False):
if p is not None and (m, n) in self._param_set:
if p is not None and (m, n) in p_set:
if p in shared_param_memo:
shared_m, shared_n = shared_param_memo[p]
shared_param_infos.append((m, n, shared_m, shared_n))
else:
shared_param_memo[p] = (m, n)
param_infos.append((m, n))
param_full_infos.append((module_name, n))
param_infos.append((module_name, m, n))
params.append(p)
del shared_param_memo
assert len(set(p.dtype for p in params)) <= 1, "expects all parameters in module to have same dtype"
assert len(set(p.dtype for p in params)) == 1, "expects all parameters in module to have same dtype"
# store the info for unflatten
self._param_infos = tuple(param_infos)
self._param_full_infos = tuple(param_full_infos)
self._shared_param_infos = tuple(shared_param_infos)
self._param_infos += param_infos
self._shared_param_infos += shared_param_infos
assert len(params) == len(set(params)), "params list should not have dups"
return params
def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None:
def _flatten_params(self, flat_params: List[FlatParameter]) -> None:
""" Flatten the managed parameters and replaced the original
attributes with views to the flat params.
"""
assert not self.is_flattened
self.is_flattened = True
if not hasattr(self, "_param_infos"):
assert flat_param is None
params = self._init_flatten_params()
flat_param = FlatParameter(params)
self.param_numel = flat_param.numel()
# flatten
assert flat_param is not None
self.register_parameter("flat_param", flat_param)
self.flat_params = flat_params
for i, flat_param in enumerate(flat_params):
self.register_parameter(f"flat_param_{i}", flat_param)
# deregister the names as parameters
for m, n in self._param_infos:
for _, m, n in self._param_infos:
delattr(m, n)
for m, n, _, _ in self._shared_param_infos:
delattr(m, n)
......@@ -200,15 +274,15 @@ class FlattenParamsWrapper(nn.Module):
# register the views as plain attributes
self._unflatten_params_as_views()
def _unflatten_params(self, external_data: Optional[Tensor] = None) -> None:
def _unflatten_params(self, external_data: Optional[List[Optional[Tensor]]] = None) -> None:
""" Undo flattening and create separate parameters from the already flattened
self.flat_params or user-supplied external data.
"""
assert self.is_flattened or external_data is not None
self.is_flattened = False
ps = self.get_param_views([external_data])
for (m, n), p in zip(self._param_infos, ps):
ps = self.get_param_views(external_data)
for (_, m, n), p in zip(self._param_infos, ps):
if hasattr(m, n):
delattr(m, n)
m.register_parameter(n, nn.Parameter(p))
......@@ -216,8 +290,11 @@ class FlattenParamsWrapper(nn.Module):
if hasattr(m, n):
delattr(m, n)
m.register_parameter(n, getattr(shared_m, shared_n))
if hasattr(self, "flat_param"):
del self.flat_param
for i, _ in enumerate(self.flat_params):
# This ensures the flat params are removed from the module.
delattr(self, f"flat_param_{i}")
self.flat_params = []
def _unflatten_params_as_views(self) -> None:
""" Unlike ``_unflatten_params``, this function unflatten into views and keep
......@@ -225,41 +302,43 @@ class FlattenParamsWrapper(nn.Module):
"""
assert self.is_flattened
ps = self.get_param_views()
for (m, n), p in zip(self._param_infos, ps):
for (_, m, n), p in zip(self._param_infos, ps):
setattr(m, n, p) # This will set as plain attr
for (m, n, shared_m, shared_n) in self._shared_param_infos:
setattr(m, n, getattr(shared_m, shared_n))
@contextmanager
def unflatten_params(self, flat_param: Optional[Tensor] = None) -> Generator:
def unflatten_params(self, flat_params: Optional[List[Tensor]] = None) -> Generator:
"""
Unflatten params. If the current instance is already unflattened, then
it will remain unflattened after the context manager exits.
Args:
flat_param (Tensor, Optional): flat param to use for unflattening.
flat_params (List[Tensor], Optional):
flat params to use for unflattening.
If provided, the current instance must be in a flattened state
at the start of the context manager. The provided tensors must be
appropriately sized and will only be used within the context
manager. After the context manager exits, we will revert to
using ``self.flat_param`` (default: None).
using ``self.flat_params``.
Default: None.
"""
assert (
flat_param is None or self.is_flattened
flat_params is None or self.is_flattened
), "Unflattening with external flat_param requires current instance to be flattened"
orig_flattened = self.is_flattened
if orig_flattened:
orig_flat_param = self.flat_param
self._unflatten_params(flat_param)
orig_flat_params = self.flat_params
self._unflatten_params(cast(Optional[List[Optional[Tensor]]], flat_params))
# Put yield in a try...finally in case the caller catches the exception and handles
# it. In that case, we need to properly handle the undoing of state.
# it. In that case, we need to properly handle the undoing of state here.
try:
yield
finally:
if orig_flattened:
self._flatten_params(orig_flat_param)
self._flatten_params(orig_flat_params)
def __getattr__(self, name: str) -> Any:
"""Forward missing attributes to wrapped module."""
......@@ -314,12 +393,15 @@ class FlattenParamsWrapper(nn.Module):
match the input state_dict.
"""
# unflatten the module automatically if the state_dict is non-flat
if self.is_flattened and "flat_param" not in state_dict:
if self.is_flattened and "flat_param_0" not in state_dict:
# This object is flattened but the state_dict is not, so we unflatten and load.
with self.unflatten_params():
return super().load_state_dict(state_dict, strict)
else:
# Otherwise, load as it.
# Otherwise, load it as is, but make older state dicts compatible.
if "flat_param" in state_dict:
state_dict["flat_param_0"] = state_dict["flat_param"]
del state_dict["flat_param"]
return super().load_state_dict(state_dict, strict)
def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
......@@ -330,9 +412,12 @@ class FlattenParamsWrapper(nn.Module):
self, external_data_list: Optional[List[Optional[Tensor]]] = None
) -> Generator[Tensor, None, None]:
""" Used to get a generator over all views from a list of external data list. """
params = [self.flat_param] # For now, there is only a single flat param.
params = self.flat_params
if external_data_list is None:
external_data_list = [None] * len(params)
assert len(external_data_list) == len(
params
), f"Incorrect external data list: {len(external_data_list)} vs. {len(params)}"
gens = []
for p, data in zip(params, external_data_list):
......@@ -344,6 +429,7 @@ class FlattenParamsWrapper(nn.Module):
def _post_state_dict_hook(
module: nn.Module, state_dict: "OrderedDict[str, Tensor]", prefix: str, *args: Any
) -> "OrderedDict[str, Tensor]":
# Move everything from ._fpw_module up one level.
replace_by_prefix_(state_dict, prefix + "_fpw_module.", prefix)
return state_dict
......@@ -351,8 +437,12 @@ def _post_state_dict_hook(
def _pre_load_state_dict_hook(
state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], prefix: str, *args: Any
) -> None:
# Push everything down to ._fpw_module level.
replace_by_prefix_(state_dict, prefix, prefix + "_fpw_module.")
# flat_param actually needs to move one level up though
# The flat_param_* keys actually need to move one level up.
flat_param_key = prefix + "_fpw_module.flat_param"
if flat_param_key in state_dict:
replace_by_prefix_(state_dict, flat_param_key, prefix + "flat_param")
for k in list(state_dict.keys()):
if k.startswith(flat_param_key):
last_part = k.split(".")[-1]
assert last_part.startswith("flat_param_"), last_part
replace_by_prefix_(state_dict, k, prefix + last_part)
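
A rough illustration of the key mapping that the two hooks above intend, re-created with plain dicts rather than ``replace_by_prefix_``; the ``layer.`` prefix and the buffer name are hypothetical:

import torch

prefix = "layer."

# Before _post_state_dict_hook runs: flat params are registered on the wrapper
# itself, while remaining tensors (e.g. buffers) live under _fpw_module.
before_save = {
    "layer.flat_param_0": torch.zeros(12),
    "layer._fpw_module.pos_embedding": torch.zeros(4),  # hypothetical buffer
}

# _post_state_dict_hook: prefix + "_fpw_module." -> prefix (move keys up one level).
after_save = {k.replace(prefix + "_fpw_module.", prefix, 1): v for k, v in before_save.items()}
assert set(after_save) == {"layer.flat_param_0", "layer.pos_embedding"}

# _pre_load_state_dict_hook: push every key down under _fpw_module., then move
# the flat_param_* keys back up one level so they match the wrapper's own params.
pushed = {prefix + "_fpw_module." + k[len(prefix):]: v for k, v in after_save.items()}
restored = {
    (prefix + k.split(".")[-1] if k.split(".")[-1].startswith("flat_param_") else k): v
    for k, v in pushed.items()
}
assert set(restored) == {"layer.flat_param_0", "layer._fpw_module.pos_embedding"}
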
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional
from typing import Optional, Tuple, Any
from .. import Size, Tensor
from ..cuda import Stream
import builtins
......@@ -14,6 +14,7 @@ class Parameter(Tensor):
_full_param_padded: Tensor
_fp32_shard: Tensor
_fp16_shard: Optional[Tensor]
_shard_bwd_hook: Tuple[Any, Any]
def __new__(cls, data: Tensor, requires_grad: builtins.bool = True): ...
......
......@@ -41,7 +41,7 @@ class TestLocalStateDict(DistributedTest):
state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
assert len(state_1) > 0
model.load_local_state_dict(state_1)
weight_key = "flat_param" if model.flatten_parameters else "embed_tokens.weight"
weight_key = "flat_param_0" if model.flatten_parameters else "embed_tokens.weight"
state_1_weight = state_1[weight_key]
assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32"
......
......@@ -44,8 +44,8 @@ def _test_func(rank, world_size, tempfile_name, unused):
# For clarity, this is what `expected_param_shapes` should look like depending on world size:
assert expected_param_shapes == {
"_fsdp_wrapped_module.flat_param": (12,),
"_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param": (6,),
"_fsdp_wrapped_module.flat_param_0": (12,),
"_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0": (6,),
}, expected_param_shapes
torch.manual_seed(1 + rank)
......
......@@ -3,9 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
Test FlattenParamsWrapper
"""
""" Test FlattenParamsWrapper on CPU and GPU (FP32 & FP16 on GPU). """
from collections import OrderedDict
import unittest
......@@ -17,12 +15,30 @@ from fairscale.utils.testing import objects_are_equal
class TestFlattenParams(unittest.TestCase):
""" Base test class and used for CPU case. """
def _get_module_init_fns(self):
return [
self._get_basic_linear_module,
self._get_shared_params_transformer,
]
def _get_empty_module(self, seed=0):
torch.manual_seed(seed) # keep everything deterministic
class Test(torch.nn.Module):
def forward(self, x):
return x + 1
module = Test()
def get_input(device, dtype):
torch.manual_seed(1) # keep everything deterministic
return torch.rand(1).to(device=device, dtype=dtype)
module.get_input = get_input
return module
def _get_transformer(self, seed=0):
torch.manual_seed(seed) # keep everything deterministic
module = torch.nn.Transformer(
......@@ -118,6 +134,44 @@ class TestFlattenParams(unittest.TestCase):
assert module.flat_param.dtype == new_dtype
assert all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters())
def test_two_flattening_group(self):
module = self._get_transformer()
num_params = sum(p.numel() for p in module.parameters())
params_to_flatten1 = list(module.encoder.layers[1].parameters()) + list(module.decoder.layers[0].parameters())
params_to_flatten2 = list(module.encoder.layers[0].parameters()) + list(module.decoder.layers[1].parameters())
num_params_to_flatten1 = sum(p.numel() for p in params_to_flatten1)
num_params_to_flatten2 = sum(p.numel() for p in params_to_flatten2)
module = FlattenParamsWrapper(module, param_list=[params_to_flatten1, params_to_flatten2])
assert module.flat_params[0].numel() == num_params_to_flatten1
assert module.flat_params[1].numel() == num_params_to_flatten2
assert sum(p.numel() for p in module.parameters()) == num_params
def test_flatten_nothing(self):
module = self._get_transformer()
ref_out = self._get_output(module)
ref_state_dict = module.state_dict()
for k, v in ref_state_dict.items():
ref_state_dict[k] = v.clone()
module = FlattenParamsWrapper(module, param_list=[[]])
fpw_state_dict = module.state_dict()
assert ref_state_dict.keys() == fpw_state_dict.keys()
for k, v in ref_state_dict.items():
torch.testing.assert_allclose(v, fpw_state_dict[k])
fpw_out = self._get_output(module)
torch.testing.assert_allclose(ref_out, fpw_out)
def test_empty_module(self):
module = self._get_empty_module()
in_data = torch.rand(1)
ref_out = module(in_data)
module = FlattenParamsWrapper(module)
assert len(list(module.parameters())) == 0
assert len(module.state_dict()) == 0
fpw_out = module(in_data)
torch.testing.assert_allclose(ref_out, fpw_out)
def test_num_params(self):
module = self._get_transformer()
self._test_num_params(module)
......@@ -210,14 +264,14 @@ class TestFlattenParams(unittest.TestCase):
# confirm that unflatten_params reflects values from new_flat_param
new_flat_param = torch.full_like(module.flat_param, fill_value=42.0)
with module.unflatten_params(flat_param=new_flat_param):
with module.unflatten_params(flat_params=[new_flat_param]):
new_state_dict = clone_state_dict()
assert new_state_dict.keys() == ref_state_dict.keys()
for k, v in new_state_dict.items():
if k in buffers: # buffers are not changed
torch.testing.assert_allclose(v, ref_state_dict[k])
else: # params reflect new_flat_param value
assert torch.all(v == 42.0)
torch.testing.assert_allclose(v, torch.ones_like(v) * 42.0)
# after context manager exits, we go back to previous (reference) state
torch.testing.assert_allclose(module.flat_param, ref_flat_param)
......