Unverified commit 9f347f37 authored by Min Xu, committed by GitHub

[fix] FSDP: EMA related fixes (#922)



* add an ignore file

* [fix] FSDP: handle the lazy_init better

- when state_dict and load_state_dict are called, they should not change
  the lazy_init state.

* changelog

* longer timeout

* Revert "longer timeout"

This reverts commit 00cc145fe86210a0972a1e7ba4f37531b9e091eb.

* testing

* adding the failed test

* fix the global to local id

* formatting

* more complete fix and test

* minor fix for an assert

* update changelog

* remove an extra line

* Update fairscale/nn/data_parallel/fsdp_optim_utils.py
Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com>

* Update fairscale/nn/data_parallel/fsdp_optim_utils.py
Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com>

* Update fairscale/nn/data_parallel/fsdp_optim_utils.py
Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com>

* addressed review comments
Co-authored-by: Min Xu <min.xu.public@gmail.com>
Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com>
parent 2ca4f0ee
......@@ -19,6 +19,7 @@ test-results/
# Coverage reports
.coverage
.coverage.*
./coverage.xml
# Environments
.env
......
......@@ -18,6 +18,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
left after children were wrapped. [#930]
- FSDP: Add support for saving optimizer state when using expert replicas with FSDP.
### Fixed
- FSDP: fixed handling of internal state in the state_dict and load_state_dict
  functions so that they don't change the lazy-init state if training hasn't started. [#922]
- FSDP: added support for optimizer state handling when some of the parameters are
  not used. An example is a model with an EMA copy that is not trained
  but still needs to be sharded. [#922]
## [0.4.5] - 2022-01-14
### Added
......
......@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
import copy
from itertools import groupby
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, cast
import torch
......@@ -20,9 +21,12 @@ UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_i
def flatten_optim_state_dict(sd: Dict) -> Dict:
"""Shard a full optimizer state dict (called by FSDP.get_shard_from_optim_state_dict)"""
param_id_map = sd["param_id_map"]
num_local_params = len(set(param_id_map.values()))
# Get a set of local ids, like {0, None, 2}, then we remove None from it.
local_ids = set(param_id_map.values())
if None in local_ids:
local_ids.remove(None)
if sd["state"]:
new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
new_state: Dict = {local_id: {} for local_id in local_ids}
singleton_state: Dict = copy.deepcopy(new_state)
else:
new_state = {}
......@@ -55,7 +59,10 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
# add pointers from the `params` dict.
for pg_id, _ in enumerate(sd["param_groups"]):
# TODO: this list could be huge. Can we avoid materializing?
# The values() list may look like [0, 0, None, None, 2, 2]. We use
# groupby to collapse the consecutive duplicates and then count the
# length of the resulting iterator.
num_local_params = sum(1 for _ in groupby(param_id_map.values()))
new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))
return new_sd
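For reference, here is a small standalone sketch of the two counting strategies used in the hunks above. The values are illustrative only, not data taken from the library:

```python
from itertools import groupby

# Hypothetical param_id_map values, mirroring the comment above: consecutive
# global params map to the same local (flat) id, and None marks params that
# have no flat optimizer state.
values = [0, 0, None, None, 2, 2]

# Set-based counting, as in flatten_optim_state_dict: the distinct real local ids.
local_ids = set(values)
local_ids.discard(None)
print(local_ids)  # {0, 2}

# groupby-based counting: collapses consecutive duplicates without
# materializing a deduplicated list; note that the None run counts as a group.
num_local_params = sum(1 for _ in groupby(values))
print(num_local_params)  # 3
```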
......@@ -91,7 +98,18 @@ def _unflatten_optim_state(
world_pad_info: List[List[List[int]]],
singleton_state: Dict[int, Dict],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
"""Convert optimizer state for flattened parameters into original, unflatten ones."""
"""Convert optimizer state for flattened parameters into original, unflattened ones.
Args:
combined_state: all-gathered state with tensors
instance_list: list of FSDP wrapper object instances
world_pad_info: [param_id][fsdp_instance_id][bytes_padded_per_rank]
singleton_state: all-gathered dimensionless tensors
Returns:
state: unflattened state dict
idx_mapping: a mapping from global ID to local ID
"""
# local ids are the keys in the current state (combined_state); there are usually fewer of them
# global ids will be the keys in the unflattened state
next_global_id = 0 # gets incremented
......@@ -100,7 +118,7 @@ def _unflatten_optim_state(
# non_tensor_state refers to entries in sd[state][param_id] that are not tensors, like "step".
# we check that these are identical across workers and then take the first
non_tensor_state = [_extract_non_tensor_state(combined_state, id) for id in combined_state]
non_tensor_state = {id: _extract_non_tensor_state(combined_state, id) for id in combined_state}
# Local corresponds to flattened, global corresponds to unflattened.
# Casting needed only for mypy.
......@@ -114,7 +132,19 @@ def _unflatten_optim_state(
global_to_local_id = {}
for local_id, num_unflat in enumerate(num_global_params):
for _ in range(num_unflat):
global_to_local_id[next_global_id] = local_id
# Some params could be unused, which means the optimizer
# hasn't created their state. Therefore, `local_id` obtained
# by enumerating the params above could be out of the range
# of keys in `combined_state` above. Here is an example:
#
# global local notes
# 0 0 FC1's weight, first flat buffer
# 1 0 FC1's bias, first flat buffer
# 2 None FC2's weight, no flat state
# 3 None FC2's bias, no flat state
# 4 2 FC3's weight, second flat buffer (but with id 2)
# 5 2 FC3's bias, second flat buffer (but with id 2)
global_to_local_id[next_global_id] = local_id if local_id in local_ids else None
next_global_id += 1
if not combined_state:
return {}, global_to_local_id
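The global-to-local mapping described in the comment can be reproduced in isolation. The sketch below uses hypothetical values for `num_global_params` and the set of local ids, chosen to match the FC1/FC2/FC3 example:

```python
# Hypothetical inputs matching the comment's table: three FSDP instances,
# each owning two unflattened params, where the middle instance (FC2) has
# no flat optimizer state.
num_global_params = [2, 2, 2]  # unflattened params per FSDP instance
local_ids = {0, 2}             # keys present in the combined optimizer state

global_to_local_id = {}
next_global_id = 0
for local_id, num_unflat in enumerate(num_global_params):
    for _ in range(num_unflat):
        # Unused instances get None instead of an out-of-range local id.
        global_to_local_id[next_global_id] = local_id if local_id in local_ids else None
        next_global_id += 1

print(global_to_local_id)
# {0: 0, 1: 0, 2: None, 3: None, 4: 2, 5: 2}
```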
......@@ -122,11 +152,20 @@ def _unflatten_optim_state(
# copy non tensor state (like the "step" count) to all global entries
unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_global_params))}
# remove the global entries that don't have optim state because the PyTorch
# optimizer's state_dict() function omits params that have no state, so we
# shouldn't emit entries like "1: {}" for missing params.
for g, l in global_to_local_id.items():
if l is None:
del unflat_state[g]
if non_tensor_state[0].keys() == combined_state[0].keys():
# Early return if there are no tensors in the state dict.
return unflat_state, global_to_local_id
local_to_global: Dict[int, List] = {i: [] for i in local_ids}
for g, l in global_to_local_id.items():
if l is not None:
local_to_global[l].append(g)
# loop over parameters in state.
# Tensor state will be padded, concatenated, and restored to original shape with FlattenParamsWrapper.get_views
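Continuing with the same hypothetical FC1/FC2/FC3 values, the sketch below walks through the pruning and inversion steps that the hunk above performs:

```python
# Hypothetical values carried over from the sketch above.
global_to_local_id = {0: 0, 1: 0, 2: None, 3: None, 4: 2, 5: 2}
local_ids = {0, 2}

# Start with a "step" entry copied to every global id, then drop the globals
# whose local id is None: unused params have no optimizer state, so keys
# like "2: {}" must not appear in the unflattened output.
unflat_state = {g: {"step": 10} for g in range(6)}
for g, l in global_to_local_id.items():
    if l is None:
        del unflat_state[g]

# Invert the mapping for the params that do have state.
local_to_global = {i: [] for i in local_ids}
for g, l in global_to_local_id.items():
    if l is not None:
        local_to_global[l].append(g)

print(sorted(unflat_state))  # [0, 1, 4, 5]
print(local_to_global)       # {0: [0, 1], 2: [4, 5]}
```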
......@@ -165,7 +204,20 @@ def build_unflat_state_dict(
uncollected_opt_state: Dict[int, Dict],
param_groups: List[Dict],
) -> Dict:
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts
from each rank. This is only called on rank 0.
Args:
instance_list: list of FSDP wrapper objects
world_pad_info: [param_id][fsdp_instance_id][bytes_padded_per_rank]
state: all-gathered combined/local/flattened state_dict
singleton_state: all-gathered singleton_state (dimensionless tensors)
uncollected_opt_state: non-tensor and not-gathered state
param_groups: the original rank 0's sd["param_groups"]
Returns:
dict: an unflattened, non-sharded optimizer state, as if FSDP were not used.
"""
assert all(len(s) == len(instance_list) for s in world_pad_info)
assert all(len(s[0]) == 1 for s in world_pad_info)
......
......@@ -796,6 +796,7 @@ class FullyShardedDataParallel(nn.Module):
)
if self.verbose:
repr = (
f"self={id(self)} is_root={self._is_root}, "
f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
f"compute_dtype={self.compute_dtype}, "
f"buffer_dtype={self.buffer_dtype}, "
......@@ -907,6 +908,7 @@ class FullyShardedDataParallel(nn.Module):
"""
if torch.cuda.is_available():
torch.cuda.synchronize()
is_uninitialized = self._is_root is None # See comment below on why we use this.
self._lazy_init()
def maybe_cast_buffers(dtype: Optional[torch.dtype] = None) -> None:
......@@ -931,6 +933,12 @@ class FullyShardedDataParallel(nn.Module):
# In case we are in mixed precision, restore buffers back to buffer_dtype.
maybe_cast_buffers()
# We shouldn't change the init state in case this was an inner module and
# users simply wanted to get state_dict before training.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._reset_lazy_init()
return state_dict
@typing.overload
......@@ -999,7 +1007,15 @@ class FullyShardedDataParallel(nn.Module):
def load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
return self._load_state_dict(state_dict, strict)
is_uninitialized = self._is_root is None # See comment below on why we use this.
sd = self._load_state_dict(state_dict, strict)
# We shouldn't change the init state in case this was an inner module and
# users simply wanted to load_state_dict before training.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._reset_lazy_init()
return sd
def load_local_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
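A minimal usage sketch of the behavior these two changes are meant to guarantee. The single-process NCCL setup below is illustrative scaffolding (it assumes one CUDA device) and is not part of the patch:

```python
import os

import torch
import torch.distributed as dist

from fairscale.nn import FullyShardedDataParallel as FSDP

# One-process group so the sketch is self-contained; assumes a single GPU.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)

model = FSDP(torch.nn.Linear(4, 4).cuda())

# Both calls happen before the first forward pass; with this fix they no
# longer leave the lazy-init state set, since the user may only have wanted
# to save or load a checkpoint before training starts.
sd = model.state_dict()
model.load_state_dict(sd)

loss = model(torch.rand(2, 4).cuda()).sum()  # the first forward does the real lazy init
loss.backward()

dist.destroy_process_group()
```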
......@@ -1297,7 +1313,7 @@ class FullyShardedDataParallel(nn.Module):
if n != "" and isinstance(m, FullyShardedDataParallel):
# We relax the assert for non-root instances, e.g. when a nested, already initialized module is
# wrapped again in FSDP later, for example after training, to run inference.
assert m._is_root is None or not m._is_root
assert m._is_root is None or not m._is_root, f"offending FSDP instance is {id(m)}, {m}"
if m._is_root is None:
m._is_root = False
if m.process_group != self.process_group:
......@@ -2363,9 +2379,10 @@ class FullyShardedDataParallel(nn.Module):
ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
if self.flatten_parameters:
full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
assert len(full_optim_state_dict["state"]) in (
0,
len(instance_list),
# Due to unused params, the length of the state can be anywhere between
# 0 and the number of params/FSDP instances.
assert len(full_optim_state_dict["state"]) <= len(
instance_list
), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}'
# get the portion of dict associated with the shard, in place
......
......@@ -300,7 +300,9 @@ class FlattenParamsWrapper(nn.Module):
assert (
len(set(p.dtype for p in params)) == 1
), f"expects all parameters to have same dtype: fp32: {fp32_msg} \n fp16: {fp16_msg} "
assert len(set(p.requires_grad for p in params)) == 1, "expects all parameters to have same requires_grad"
assert (
len(set(p.requires_grad for p in params)) == 1
), f"expects all parameters to have same requires_grad {p_set}"
assert len(params) == len(set(params)), "params list should not have dups"
return params, param_infos, shared_param_infos
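A hedged illustration of why this invariant matters for the EMA scenario in the changelog and the test below: mixing trained and frozen parameters in one flat buffer yields two distinct requires_grad flags, which is presumably one reason the frozen module gets its own FSDP wrapper. The module names here are made up:

```python
import torch.nn as nn

# Two hypothetical submodules: one trained, one frozen (EMA-style).
trained = nn.Linear(4, 4)
frozen = nn.Linear(4, 4).requires_grad_(False)

# If both sets of params were flattened into a single buffer, the
# requires_grad flags would not be uniform and the assert above would fire.
flags = set(p.requires_grad for m in (trained, frozen) for p in m.parameters())
print(flags)            # {False, True}
assert len(flags) != 1  # mixed flags: exactly what the wrapper refuses to flatten
```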
......
......@@ -8,6 +8,7 @@ import unittest
from parameterized import parameterized
import torch
from torch import nn
from torch.optim import SGD, Adadelta, Adam # type: ignore
from fairscale.nn import FullyShardedDataParallel
......@@ -225,6 +226,34 @@ class TestOptimizerUtils(DistributedTest):
assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)
@parameterized.expand(
[(True,), (False,)],
name_func=rename_test,
)
def test_model_with_unused_params(self, wrap_l2):
"""Test handling of model with unused params by gather_full_optim_state_dict()"""
test_fn = functools.partial(self._test_model_with_unused_params, wrap_l2=wrap_l2)
spawn_and_init(test_fn, world_sizes=[2])
@classmethod
def _test_model_with_unused_params(self, rank, pg, wrap_l2):
model = ModelWithUnusedParams(wrap_l2).cuda()
data = torch.rand(4).cuda().requires_grad_(True)
model = FullyShardedDataParallel(model)
optim = SGD(model.parameters(), momentum=0.9, lr=0.1)
out = model(data).sum()
out.backward()
optim.step()
model.zero_grad(set_to_none=True)
sd = model.gather_full_optim_state_dict(optim)
if rank == 0:
shard_sd = model.get_shard_from_optim_state_dict(sd)
orig_sd = optim.state_dict()
orig_sd = recursive_copy_to_device(orig_sd, non_blocking=False, device="cpu")
objects_are_equal(shard_sd, orig_sd, raise_exception=True)
else:
assert sd is None, sd
def test_named_params_ordering(self):
"""Test assumption of consolidate_optimizer_state_dict"""
group = DummyProcessGroup(0, 1)
......@@ -234,8 +263,33 @@ class TestOptimizerUtils(DistributedTest):
assert objects_are_equal(p, named_pars[i])
def test_is_singleton_tensor(self):
"""Test is_singleton_tensor function"""
assert is_singleton_tensor(torch.tensor(4.0))
assert not is_singleton_tensor(torch.tensor([4.0]))
assert not is_singleton_tensor(torch.tensor([4.0, 5.0]))
assert not is_singleton_tensor([4.0])
assert not is_singleton_tensor(4.0)
class ModelWithUnusedParams(nn.Module):
def __init__(self, wrap_l2):
super().__init__()
self.l = nn.Linear(4, 4)
# The unused param must be wrapped; otherwise, due to flattening, it
# is always considered used.
self.not_trained = nn.Linear(4, 4).requires_grad_(False)
self.not_trained = FullyShardedDataParallel(self.not_trained)
# Optionally test a used param that comes after the unused one by
# wrapping it.
self.l2 = nn.Linear(4, 4)
if wrap_l2:
# When wrapping happens, the unused param will be in the middle
# of the param list (for optimizer state dict), not at the
# end. This way, we can test the handling code in more corner
# cases.
self.l2 = FullyShardedDataParallel(self.l2)
def forward(self, x):
with torch.no_grad():
y = self.not_trained(x)
return self.l2(self.l(x)) - y