Unverified Commit 429f3d31 authored by Min Xu, committed by GitHub

[fix] better handling non-flatten in FSDP (#1072)



* [fix] better handling non-flatten in FSDP

- see the detailed comment about that backward firing case
- also minor debugging help in FSDP
- also minor fix in FPW's state dict

* [feat] disallow reset_parameters by default

* [feat] adding fsdp_instances API - useful for checking wrapping in user code

* [fix] one line fix but more than a day of debugging

* fixed the case of loading a combined checkpoint with empty FSDP instances

* fixed another bug in state loading where root/non-root modules kept stale cached full params because they were not resharded after forward

* [feat] support .half and .float better

* fixed a bug where gathering optim state lost extra keys from the original state_dict

* fixed a test failure in mixed precision

* fixed another bug affecting no_sync grad acc

* fixed a bug and a test in fsdp optim state

* fixed another corner case

* added a comment

* skip ssd offload tests

* skip the FSDP one for ssd offload
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 47ce21ac
......@@ -500,7 +500,7 @@ class LayerwiseMemoryTracker:
Indicate if x and y share the same storage, meaning that one of them
is a view, reshape or stride of the other or from a common tensor
"""
return x.storage().data_ptr() == y.storage().data_ptr() # type: ignore
return x.storage().data_ptr() == y.storage().data_ptr()
@staticmethod
def _collect_tensors(module_io_tensors: Union[torch.Tensor, Sequence[torch.Tensor]]) -> List[torch.Tensor]:
......
......@@ -12,6 +12,7 @@ from .fully_sharded_data_parallel import (
OffloadConfig,
TrainingState,
auto_wrap_bn,
get_fsdp_instances,
no_pre_load_state_dict_hook,
)
......
......@@ -14,9 +14,6 @@ from fairscale.nn.misc import FlattenParamsWrapper
if TYPE_CHECKING:
from fairscale.nn.data_parallel import FullyShardedDataParallel
# These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}
# This function helps shard a full optimizer state dict
def flatten_optim_state_dict(sd: Dict) -> Dict:
"""Shard a full optimizer state dict (called by FSDP.get_shard_from_optim_state_dict)"""
......@@ -52,20 +49,24 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
new_state[local_id][buffer_name] = torch.cat(tensors)
new_state[local_id].update(non_tensor_state)
new_state[local_id].update(singleton_state[local_id])
new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
for k in sd.keys(): # if there are extra keys, like loss_scale, don't delete them
if k not in UNFLAT_RETURN_KEYS:
new_sd[k] = copy.deepcopy(sd[k])
# Now make a new param_groups copy and update it.
new_sd_pg = copy.deepcopy(sd["param_groups"])
# add pointers from the `params` dict.
for pg_id, _ in enumerate(sd["param_groups"]):
# The values() list may look like [0,0,None,None,2,2]. We use
# groupby to remove the duplicates and then count the length of
# resulting iter.
num_local_params = sum(1 for _ in groupby(param_id_map.values()))
new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))
new_sd_pg[pg_id]["params"] = list(range(num_local_params))
return new_sd
# update the original sd so that we don't lose extra keys, like loss_scale.
sd["state"] = new_state
sd["param_groups"] = new_sd_pg
# delete extra keys we have added to match the original state.
del sd["uncollected_local_ids"]
del sd["param_id_map"]
return sd
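As a side note, here is a minimal sketch of the ``groupby`` counting used above, with a hypothetical ``param_id_map`` whose values contain duplicates and ``None`` entries; the names are illustrative only and not part of the diff.
from itertools import groupby

# Hypothetical param_id_map values: two params flattened into local id 0,
# two unflattened params (None), two flattened into local id 2.
values = [0, 0, None, None, 2, 2]
num_local_params = sum(1 for _ in groupby(values))
assert num_local_params == 3  # consecutive duplicates collapse into one group each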
def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None:
......@@ -202,7 +203,7 @@ def build_unflat_state_dict(
state: Dict[int, Dict[str, List[torch.Tensor]]],
singleton_state: Dict[int, Dict[str, List[torch.Tensor]]],
uncollected_opt_state: Dict[int, Dict],
param_groups: List[Dict],
original_sd: Dict,
) -> Dict:
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts
from each rank. This is only called on rank 0.
......@@ -213,7 +214,7 @@ def build_unflat_state_dict(
state: all-gathered combined/local/flatten state_dict
singleton_state: all-gathered singleton_state (dimensionless tensors)
uncollected_opt_state: non-tensor and not-gathered state
param_groups: the original rank 0's sd["param_groups"]
original_sd: the original rank 0's sd
Returns:
dict: an unflattened, nonsharded optimizer state, as if FSDP was not there.
......@@ -228,19 +229,19 @@ def build_unflat_state_dict(
singleton_state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if is_singleton_tensor(x)}
# local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(state, instance_list, world_pad_info, singleton_state)
# Since there are no tensors in param_groups, deepcopy is fine.
param_groups = copy.deepcopy(param_groups)
param_groups = copy.deepcopy(original_sd["param_groups"])
# Casting needed only for mypy.
num_params = sum([cast(int, m.num_params_managed) for m in instance_list])
param_groups[0]["params"] = list(range(num_params))
unflat_optim_state_dict = {
"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted
"param_id_map": global_to_local_id,
"param_groups": param_groups,
"uncollected_local_ids": list(uncollected_opt_state.keys()),
}
assert set(unflat_optim_state_dict.keys()) == UNFLAT_RETURN_KEYS
return unflat_optim_state_dict
# Update the original sd so we don't lose extra state like loss_scale.
original_sd["state"] = dict(sorted(unflat_state.items())) # NOTE: this is probably already sorted
original_sd["param_id_map"] = global_to_local_id
original_sd["param_groups"] = param_groups
original_sd["uncollected_local_ids"] = list(uncollected_opt_state.keys())
return original_sd
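A toy, hedged illustration (not part of this diff) of why updating the original ``sd`` in place preserves extra keys such as ``loss_scale``:
import copy

# Hypothetical rank-0 optimizer state dict with an extra trainer-added key.
original_sd = {"state": {}, "param_groups": [{"lr": 0.1, "params": [0]}], "loss_scale": 128.0}
param_groups = copy.deepcopy(original_sd["param_groups"])
param_groups[0]["params"] = list(range(3))  # pretend there are 3 unflattened params

# Mirror the in-place update above: only known keys are overwritten,
# so "loss_scale" survives untouched.
original_sd["state"] = {}
original_sd["param_id_map"] = {0: 0, 1: 0, 2: 0}
original_sd["param_groups"] = param_groups
original_sd["uncollected_local_ids"] = []
assert original_sd["loss_scale"] == 128.0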
def is_singleton_tensor(x: Any) -> bool:
......
......@@ -217,6 +217,8 @@ class FullyShardedDataParallel(nn.Module):
save memory. Consider a case where an FSDP root module is a submodule of a model.
The backward pass may not start immediately after the FSDP root module finishes its forward.
So resharding the parameters of the FSDP root module can help save memory in this case.
In certain cases, performance is not even slower, because the cached full-param
state may be stale anyway due to load_local_state_dict() calls.
Default: True.
mixed_precision (bool, Optional):
if ``True``, inputs, activations and gradients will be kept in FP16;
......@@ -230,6 +232,7 @@ class FullyShardedDataParallel(nn.Module):
which improves training speed.
move_params_to_cpu (bool, Optional):
if ``True``, offload params to CPU.
Default: False
compute_dtype (torch.dtype, Optional):
dtype for full parameters for computation. This defaults to
``torch.float32`` unless *``mixed_precision``* is set, in which case
......@@ -308,6 +311,12 @@ class FullyShardedDataParallel(nn.Module):
rank 0 and return an empty dict on non-rank-0, which allows FullyShardedDataParallel to
skip the GPU -> CPU copy on non-rank-0 altogether and prevents OOM.
Default: False
gradient_predivide_factor (float, optional):
If supplied, pre-divide the gradients before the reduce-scatter.
Default: None
allow_reset_parameters (bool):
If True, allow ``reset_parameters`` API to be proxied to the wrapped module.
Default: False
"""
def __init__(
......@@ -336,6 +345,7 @@ class FullyShardedDataParallel(nn.Module):
offload_config: Optional[OffloadConfig] = None,
state_dict_on_rank_0_only: bool = False,
gradient_predivide_factor: Optional[float] = None,
allow_reset_parameters: bool = False,
):
try:
import torch._C
......@@ -415,6 +425,7 @@ class FullyShardedDataParallel(nn.Module):
self.world_size
)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self.allow_reset_parameters = allow_reset_parameters
self.numel_padded_per_param: List[int] = []
self._tstart = time.time()
......@@ -488,6 +499,10 @@ class FullyShardedDataParallel(nn.Module):
for p in self.params:
if p.dtype is not torch.float16:
raise ValueError("Expecting FP16 param type in pure FP16 mode.")
else:
for p in self.params:
if p.dtype is not torch.float32:
raise ValueError("Expecting FP16 param type in FP32 & MP modes.")
# Shard module parameters in place
self._shard_parameters_()
......@@ -531,9 +546,8 @@ class FullyShardedDataParallel(nn.Module):
# Free all params at the end of initialization.
if self.ssd_offload:
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
m._free_ssd_offload()
for m in get_fsdp_instances(self):
m._free_ssd_offload()
def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
......@@ -820,12 +834,12 @@ class FullyShardedDataParallel(nn.Module):
f"compute_dtype={self.compute_dtype}, "
f"buffer_dtype={self.buffer_dtype}, "
f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
f"compute_device={self.compute_device}"
f"compute_device={self.compute_device}, "
f"move_params_to_cpu={self.move_params_to_cpu}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}, "
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}"
f"force_input_to_fp32={self.force_input_to_fp32}"
f"clear_autocast_cache={self.clear_autocast_cache}, "
f"force_input_to_fp32={self.force_input_to_fp32}, "
)
return repr
......@@ -989,9 +1003,8 @@ class FullyShardedDataParallel(nn.Module):
)
with contextlib.ExitStack() as stack:
# Tell any nested FSDP instances not to auto summon full params.
for module in self.modules(): # includes self
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module._no_return_full_state_dict())
for module in get_fsdp_instances(self):
stack.enter_context(module._no_return_full_state_dict())
# We need to specially call FSDP's state_dict function in case
# self.state_dict is a function from a child class of FSDP.
return FullyShardedDataParallel.state_dict(self, *args, **kwargs)
......@@ -1057,10 +1070,22 @@ class FullyShardedDataParallel(nn.Module):
)
with contextlib.ExitStack() as stack:
# Tell any nested FSDP instances not to auto summon full params.
for module in self.modules(): # includes self
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module._no_return_full_state_dict())
for module in get_fsdp_instances(self):
stack.enter_context(module._no_return_full_state_dict())
output = self._load_state_dict(state_dict, strict)
# After loading the local state, if an FSDP wrapper still has its full
# params built, it would keep using the stale values. Therefore we call
# _free_full_params() here to force the full params to be rebuilt from
# the updated shards the next time they are needed.
#
# There are 2 cases where this can happen:
# 1. in training, the outermost wrapper may have reshard_after_forward set
#    to False (_is_root is True); therefore, the full param is built and kept.
# 2. in eval, inner modules may get called directly, hence there can be
#    multiple "root" instances; therefore, we need to loop over all
#    instances below to free the full params.
for module in get_fsdp_instances(self):
module._free_full_params()
return output
@contextlib.contextmanager
......@@ -1077,7 +1102,8 @@ class FullyShardedDataParallel(nn.Module):
.. note:: Gradient accumulation can be done without this context,
avoiding the extra GPU memory overhead, but with the extra
networking overhead.
networking overhead. That is, the trainer loop simply does multiple
fwd/bwd passes without step() and without the no_sync context.
"""
self._lazy_init()
assert self._is_root, "no_sync on inner FSDP is not supported"
......@@ -1085,10 +1111,9 @@ class FullyShardedDataParallel(nn.Module):
# This instance may wrap other FSDP instances and we
# need to set all of them to accumulate gradients.
old_flags = []
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
old_flags.append((m, m._require_backward_grad_sync))
m._require_backward_grad_sync = False
for m in get_fsdp_instances(self):
old_flags.append((m, m._require_backward_grad_sync))
m._require_backward_grad_sync = False
try:
yield
finally:
......@@ -1096,6 +1121,89 @@ class FullyShardedDataParallel(nn.Module):
assert m._require_backward_grad_sync is False
m._require_backward_grad_sync = old_flag
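A hedged sketch of the two gradient-accumulation styles the docstring refers to; ``fsdp_model``, ``optimizer`` and ``micro_batches`` are assumed placeholders, not part of this diff.
def grad_acc_without_no_sync(fsdp_model, optimizer, micro_batches):
    # Lower GPU memory, but gradients are reduce-scattered on every backward pass.
    for x in micro_batches:
        fsdp_model(x).sum().backward()
    optimizer.step()
    optimizer.zero_grad()

def grad_acc_with_no_sync(fsdp_model, optimizer, micro_batches):
    # Skips gradient reduction on all but the last pass, trading extra GPU
    # memory for less communication.
    with fsdp_model.no_sync():
        for x in micro_batches[:-1]:
            fsdp_model(x).sum().backward()
    fsdp_model(micro_batches[-1]).sum().backward()
    optimizer.step()
    optimizer.zero_grad()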
def reset_parameters(self) -> None:
"""Special reset_parameters API handling.
We don't by default allow this API because it has at least 2 issues:
1. calling it after wrapping can crash due to unexpected tensor size
and dimensions due to flattening and sharding. So summon_full_params
context might be required.
2. calling it after wrapping can result in incorrect init values due
to flattening.
See this gist for an example of the init issue when parameters are
flatten.
https://gist.github.com/407bb158f0d0612e157c2cbcf5c8b76a
Or, like in 1, init function can silently init the weight differently
because of the dimensions.
Finally, be advised that init on CPU vs. on GPU can have different
values. If models are originally on CPU and after wrapping it is moved
to GPU, calling this will again be problematic.
"""
if self.allow_reset_parameters:
self.module.reset_parameters()
else:
raise RuntimeError("reset parameters after FSDP wrapping is not allowed")
def _apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
"""Hook into model conversion, like .half() and .float()
When users call module.half() or module.float() after FSDP wrapping,
we need to update some internal states here.
Args:
fn (Callable):
same as nn.Module's _apply.
Returns:
(Any):
same as nn.Module's _apply.
"""
# Just a precaution. Doing the conversion while in the IDLE state is safest.
self.assert_state(TrainingState.IDLE)
# In order to determine how to change compute_dtype, we need to
# remember the dtype before this call.
if len(self.params):
dtype_before = self.params[0].dtype
# Call nn.Module's _apply.
ret = super()._apply(fn)
# Make sure we update p._full_param_padded according to the new dtype if we are
# not in mixed_precision. In mixed precision, doing m.half() or m.float() doesn't
# really make much sense, but we need to allow it in case the user just wants to
# temporarily .half() and then .float() back for some reason.
if not self.mixed_precision:
for p in self.params:
if hasattr(p, "_full_param_padded"):
allocated = False
if p._full_param_padded.storage().size() == 0:
allocated = True
alloc_storage_(p._full_param_padded, size=p._full_param_padded.size())
p._full_param_padded = p._full_param_padded.to(dtype=p.data.dtype, device=p.data.device)
if allocated:
free_storage_(p._full_param_padded)
# Update compute_dtype; otherwise p._full_param_padded would keep being
# rebuilt in the old dtype.
if len(self.params):
dtype_after = self.params[0].dtype
if dtype_before != dtype_after:
# There are 4 cases below. Only 2 result in compute_dtype change
# to the dtype_after.
#
#               16 -> 32    32 -> 16
#   mixed       n/a         no change
#   not mixed   change      change
if not self.mixed_precision:
self.compute_dtype = dtype_after
return ret
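A hedged sketch of what this ``_apply`` hook enables, i.e. post-wrapping dtype conversion when ``mixed_precision=False`` (assumes an initialized process group and a CUDA device).
import torch.nn as nn

fsdp = FullyShardedDataParallel(nn.Linear(8, 8).cuda(), mixed_precision=False)
fsdp = fsdp.half()   # params become fp16 and compute_dtype follows to torch.float16
fsdp = fsdp.float()  # and back: compute_dtype returns to torch.float32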
@contextlib.contextmanager
def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
"""
......@@ -1138,8 +1246,7 @@ class FullyShardedDataParallel(nn.Module):
torch.cuda.synchronize()
self._lazy_init()
self.assert_state(TrainingState.IDLE)
# Set the state so that we assert when trying to go into
# forward/backward.
# Set the state so that we assert when trying to go into fwd/bwd.
self.training_state = TrainingState.SUMMON_FULL_PARAMS
full_tensors = self._rebuild_full_params(force_full_precision=True)
assert full_tensors is not None
......@@ -1267,9 +1374,9 @@ class FullyShardedDataParallel(nn.Module):
p._fp32_shard = p.data
if self.mixed_precision:
assert p._fp32_shard.dtype == torch.float32
assert p._fp32_shard.dtype == torch.float32, self
if self.move_params_to_cpu:
assert p._fp32_shard.device == torch.device("cpu")
assert p._fp32_shard.device == torch.device("cpu"), self
# We don't pin memory when using ssd_offload since that results in OOM when
# the memory requirements of a model are larger than host memory.
......@@ -1817,6 +1924,29 @@ class FullyShardedDataParallel(nn.Module):
def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
"""Helper used below on all fsdp modules."""
if not fsdp_module._is_root and self._require_backward_grad_sync:
# We make sure to switch to fp32 shards here because there might be
# params lingering in full_param mode if the following firing order happens:
# pre-bwd: rebuild and use full for p1 and p2
# post-bwd for p1: free and switch to fp32 shard for p1
# pre-bwd: rebuild again for p1 and p2
# post-bwd for p2: free and switch to fp32 shard for p2
# In the end, p1 will be left in full param mode.
#
# We need to switch to fp32 *and* free the full params. If we don't free,
# we end up reusing a potentially *stale* full param after the fp32
# shard is updated (e.g. by optimizer.step()).
#
# We skip the root because it may have reshard=False, which means
# we want to keep the speed benefit of that. I haven't seen a case
# where this is needed on the root module.
#
# We also skip this in grad accumulation steps since we want to keep the
# full params, as they haven't been updated yet. See the comment on ``no_sync``
# for when to use no_sync-style grad accumulation. For FSDP, it is more likely
# that you want grad accumulation without no_sync.
fsdp_module._free_full_params()
fsdp_module._use_fp32_param_shard()
for p in fsdp_module.params:
if not p.requires_grad:
continue
......@@ -1828,7 +1958,7 @@ class FullyShardedDataParallel(nn.Module):
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
# remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
# sync passes, if desired.
# sync passes.
if not self._require_backward_grad_sync:
continue
......@@ -1841,46 +1971,49 @@ class FullyShardedDataParallel(nn.Module):
p.device == p._saved_grad_shard.device,
f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
)
p_assert(
p.shape == p._saved_grad_shard.shape,
f"WFPB: incorrect saved_grad_shard shape {p.shape} vs {p._saved_grad_shard.shape}",
)
p.grad = p._saved_grad_shard
if hasattr(p, "_saved_grad_shard"):
delattr(p, "_saved_grad_shard")
# Update root and nested FSDP's hooks and flags.
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
_finalize_parameters(m)
m._free_ssd_offload()
m._pre_backward_hook_has_run = False
if any(p.requires_grad for p in m.parameters()):
# Check if the module has params and if any of them has
# the `requires_grad` field set. If `requires_grad=False` for
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.BACKWARD_PRE`.
if any([p.requires_grad for p in m.params]):
m.assert_state(TrainingState.BACKWARD_POST)
else:
m.assert_state(TrainingState.BACKWARD_PRE)
for m in get_fsdp_instances(self):
_finalize_parameters(m)
m._free_ssd_offload()
m._pre_backward_hook_has_run = False
if any(p.requires_grad for p in m.parameters()):
# Check if the module has params and if any of them has
# the `requires_grad` field set. If `requires_grad=False` for
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.BACKWARD_PRE`.
if any([p.requires_grad for p in m.params]):
m.assert_state(TrainingState.BACKWARD_POST)
else:
# When `m` and its children have no params, or have params but
# none with `requires_grad==True`, there are two cases:
# 1. output tensors are `requires_grad==True`. In this case,
# pre-backward hook is still registered, so it is in BACKWARD_PRE state.
# 2. output tensors are `requires_grad==False`. In this case,
# pre-backward hook is not registered, so it is in IDLE state.
m.assert_state([TrainingState.BACKWARD_PRE, TrainingState.IDLE])
m.training_state = TrainingState.IDLE
if m._is_root:
# reset this flag for cases like "one forward pass + multiple backward passes"
self._post_backward_callback_queued = False
# clear this list for next iteration
p_assert(
self._output_pre_backward_hook_registered is not None,
"WFPB: self._output_pre_backward_hook_registered should not be None",
)
assert self._output_pre_backward_hook_registered is not None # make mypy happy
self._output_pre_backward_hook_registered.clear()
m.assert_state(TrainingState.BACKWARD_PRE)
else:
# When `m` and its children have no params, or have params but
# none with `requires_grad==True`, there are two cases:
# 1. output tensors are `requires_grad==True`. In this case,
# pre-backward hook is still registered, so it is in BACKWARD_PRE state.
# 2. output tensors are `requires_grad==False`. In this case,
# pre-backward hook is not registered, so it is in IDLE state.
m.assert_state([TrainingState.BACKWARD_PRE, TrainingState.IDLE])
m.training_state = TrainingState.IDLE
if m._is_root:
# reset this flag for cases like "one forward pass + multiple backward passes"
self._post_backward_callback_queued = False
# clear this list for next iteration
p_assert(
self._output_pre_backward_hook_registered is not None,
"WFPB: self._output_pre_backward_hook_registered should not be None",
)
assert self._output_pre_backward_hook_registered is not None # make mypy happy
self._output_pre_backward_hook_registered.clear()
@torch.no_grad()
def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
......@@ -1979,25 +2112,38 @@ class FullyShardedDataParallel(nn.Module):
if not p._is_sharded: # e.g., when world_size == 1
update_p_data()
else:
# Skip if already built. Only shared param can be rebuilt multiple times.
# Skip if already built.
#
# case 1: shared param can be rebuilt multiple times.
# A corner case is p._orig_size = (1,), which means the shape equality is
# not a perfect check. But we assume we don't share a param with shape (1,).
if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared:
# We do use size (1,) in unit testing at least.
# case 2: with multiple params (like non-flatten, or multiple flatten groups)
# we may have pre & post bwd firing order issues. See the comments in the
# _finalize_parameters function for such a case.
if p.data.shape == p._orig_size and p._orig_size != (1,):
assert p.data.storage().data_ptr() == p._full_param_padded.storage().data_ptr(), (
f"p.data {p.data.storage().data_ptr()} "
f"p._fp32_shard {p._fp32_shard.storage().data_ptr()} "
f"p._fp16_shard {p._fp16_shard.storage().data_ptr() if p._fp16_shard is not None else None} "
f"p._full_params_padded {p._full_param_padded.storage().data_ptr()} "
)
continue
# If self.move_params_to_cpu and force_full_precision, we need to cast
# the FP32 CPU param to CUDA for the all-gather.
p_data = p.data.to(p._full_param_padded.device, non_blocking=True)
p_size = p._full_param_padded.size()
assert p_size.numel() % self.world_size == 0
full_p_size = p._full_param_padded.size()
assert full_p_size.numel() % self.world_size == 0
if self.mixed_precision and force_full_precision:
# Allocate a fresh tensor in full precision since we are in
# mixed precision and a full precision rebuild is requested.
output_tensor = p_data.new_zeros(p_size)
output_tensor = p_data.new_zeros(full_p_size)
else:
if p._full_param_padded.storage().size() != p_size.numel():
if p._full_param_padded.storage().size() != full_p_size.numel():
# Allocate based on full size from all shards.
alloc_storage_(p._full_param_padded, size=p_size)
alloc_storage_(p._full_param_padded, size=full_p_size)
output_tensor = p._full_param_padded
# Fill output_tensor with (p.data for each shard in self.world_size)
......@@ -2284,10 +2430,10 @@ class FullyShardedDataParallel(nn.Module):
raise ValueError(msg)
def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
"""Collect [x.numel_padded_per_param for x in self._fsdp_instances()] from each rank."""
"""Collect [x.numel_padded_per_param for x in get_fsdp_instances(self)] from each rank."""
world_pad_info: List[List[List[int]]] = [] # this will contain values from the whole world.
my_pad_info: List[List[int]] = [
cast(List[int], m.numel_padded_per_param) for m in self._fsdp_instances(skip_empty=True)
cast(List[int], m.numel_padded_per_param) for m in get_fsdp_instances(self, skip_empty=True)
]
for rank in range(self.world_size):
if rank == self.rank:
......@@ -2307,7 +2453,7 @@ class FullyShardedDataParallel(nn.Module):
singleton_state: Dict[int, Dict[str, List[Any]]] = {} # Dimensionless tensor
# Non-empty FSDP instance and sd_state item number must match.
fsdp_instances = self._fsdp_instances(skip_empty=True)
fsdp_instances = get_fsdp_instances(self, skip_empty=True)
assert len(fsdp_instances) >= len(sd_state), f"{len(fsdp_instances)} vs. {len(sd_state)}"
for k, v in sd_state.items():
......@@ -2377,7 +2523,9 @@ class FullyShardedDataParallel(nn.Module):
self._lazy_init()
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
assert set(sd.keys()) == {"param_groups", "state"}, f'{set(sd.keys())} != {"param_groups", "state"}'
assert {"param_groups", "state"}.issubset(
set(sd.keys())
), f'{set(sd.keys())} not a superset of {"param_groups", "state"}'
assert len(sd["param_groups"]) == 1, "Param groups are not supported"
# We use all_gather to consolidate OSD['state'] and broadcast to consolidate the other keys (like param_groups)
state, singleton_state = self._gather_optim_state(sd.pop("state"))
......@@ -2386,29 +2534,30 @@ class FullyShardedDataParallel(nn.Module):
return None
# Unify the shard states by concatenating tensors and unflattening params
new_state_dict = ou.build_unflat_state_dict(
self._fsdp_instances(skip_empty=True),
get_fsdp_instances(self, skip_empty=True),
pad_info,
state,
singleton_state,
self.uncollected_opt_state,
sd["param_groups"],
sd,
)
self.uncollected_opt_state = {}
assert "uncollected_local_ids" in new_state_dict
return new_state_dict
def _fsdp_instances(self, skip_empty: bool = False) -> List["FullyShardedDataParallel"]:
"""Returns all fsdp modules in self.modules() including self."""
result = [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
if skip_empty:
result = list(filter(lambda x: len(cast(FullyShardedDataParallel, x).non_shared_params()) > 0, result))
return result
def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
"""Return a new state dict filtering out the ones like MoE layers, which has
``no_broadcast_optim_state`` flag set.
We also make rooms for the optimizer state on rank 0.
Args:
osd (Dict):
Optimizer state dict from a rank. osd["state"] is what we mainly
care. Osd may contain other keys and values, we need to keep. Therefore,
we only change osd["state"] and not returning a new copy of osd
which is slower and may also lose extra fields, like "loss_scale"
used by fairseq.
"""
# In PyTorch version 1.12, Adam's `step` state changed from an int to a singleton
# tensor. We convert it back here. Otherwise, the step counter will be treated
......@@ -2419,8 +2568,8 @@ class FullyShardedDataParallel(nn.Module):
if ou.is_singleton_tensor(bufs["step"]):
bufs["step"] = bufs["step"].item()
# Get uncollected_ids.
uncollected_ids = [i for i, m in enumerate(self._fsdp_instances()) if m.no_broadcast_optim_state]
new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
uncollected_ids = [i for i, m in enumerate(get_fsdp_instances(self)) if m.no_broadcast_optim_state]
new_state_value = {k: v for k, v in osd["state"].items() if k not in uncollected_ids}
if self.rank == 0:
# Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
self.uncollected_opt_state = {
......@@ -2429,9 +2578,8 @@ class FullyShardedDataParallel(nn.Module):
if k in uncollected_ids
}
pg = copy.deepcopy(osd["param_groups"])
new_dct["param_groups"] = pg
return new_dct
osd["state"] = new_state_value
return osd
def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Get the portion of the optimizer state dict associated with the shard
......@@ -2439,14 +2587,21 @@ class FullyShardedDataParallel(nn.Module):
This can be used to get the right sharded optimizer state to be loaded
into the sharded optimizer for this FSDP rank.
.. warning:: The input state dict is modified in-place assuming the original
full state isn't going to be used anymore. This is done so that
we don't need to copy the extra state in it. It is the caller's responsibility
to make copies if it doesn't want the original state dict modified.
Args:
full_optim_state_dict (dict): consolidated optimizer state returned by ``gather_full_optim_state``, or loaded from a checkpoint.
full_optim_state_dict (dict):
consolidated optimizer state returned by ``gather_full_optim_state``,
or loaded from a checkpoint.
Returns:
(dict): a shard of the optimizer state.
"""
# Assert nesting is the same as it was at save time
instance_list = self._fsdp_instances()
instance_list = get_fsdp_instances(self, skip_empty=True)
ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
if self.flatten_parameters:
......@@ -2458,9 +2613,9 @@ class FullyShardedDataParallel(nn.Module):
), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}'
# get the portion of dict associated with the shard, in place
for id, s in full_optim_state_dict["state"].items():
for _id, s in full_optim_state_dict["state"].items():
for k, v in s.items():
if torch.is_tensor(v) and id not in ids_not_to_shard:
if torch.is_tensor(v) and _id not in ids_not_to_shard:
v_shard, _ = self._get_shard(v)
elif isinstance(v, list) and ou.is_singleton_tensor(v[0]):
# if we are resuming on larger world size, take first entry
......@@ -2468,7 +2623,7 @@ class FullyShardedDataParallel(nn.Module):
assert ou.is_singleton_tensor(v_shard)
else:
v_shard = v # don't shard entries that are not tensors
full_optim_state_dict["state"][id][k] = v_shard
full_optim_state_dict["state"][_id][k] = v_shard
return full_optim_state_dict
......@@ -2728,3 +2883,22 @@ def auto_wrap_bn(
enable_wrap(config_auto_wrap_policy, wrapper_cls=FullyShardedDataParallel) if wrap_it else contextlib.suppress()
):
return auto_wrap(module)
def get_fsdp_instances(mod: nn.Module, skip_empty: bool = False) -> List[FullyShardedDataParallel]:
"""Return, a list, if any, of the module/submodule is wrapped by FSDP within another module.
Args:
mod (nn.Module):
A nn.Module module to be scanned.
skip_empty (bool):
If True, skip wrappers without any parameters.
Default: False
"""
ret: List[FullyShardedDataParallel] = []
for m in mod.modules(): # including mod itself
if isinstance(m, FullyShardedDataParallel):
ret.append(m)
if skip_empty:
ret = list(filter(lambda x: len(cast(FullyShardedDataParallel, x).non_shared_params()) > 0, ret))
return ret
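A small usage sketch of the new helper, e.g. for user code that wants to verify how a model was (auto-)wrapped; ``model`` is an assumed, already-wrapped module.
from fairscale.nn.data_parallel import get_fsdp_instances

def check_wrapping(model):
    instances = get_fsdp_instances(model)
    nonempty = get_fsdp_instances(model, skip_empty=True)
    print(f"{len(instances)} FSDP wrappers, {len(nonempty)} with non-shared params")
    return instances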
......@@ -500,7 +500,11 @@ class FlattenParamsWrapper(nn.Module):
# Unflatten the module automatically if the state_dict is non-flat.
# Note, we check the flat_param_ prefix since custom names can be given and flat_param_0 is
# not always in the state dict's key list.
if self.is_flattened and not any(k.startswith("flat_param_") for k in state_dict.keys()):
if (
self.num_params_managed > 0
and self.is_flattened
and not any(k.startswith("flat_param_") for k in state_dict.keys())
):
# This object is flattened but the state_dict is not. So we unflatten and load.
with self.unflatten_params():
return super().load_state_dict(state_dict, strict)
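A toy illustration of the flat-vs-non-flat detection this hunk relies on; the keys and shapes are hypothetical.
import torch

flat_sd = {"flat_param_0": torch.zeros(110)}
nonflat_sd = {"weight": torch.zeros(10, 10), "bias": torch.zeros(10)}
assert any(k.startswith("flat_param_") for k in flat_sd.keys())
assert not any(k.startswith("flat_param_") for k in nonflat_sd.keys())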
......
......@@ -32,11 +32,9 @@ from . import utils as utils
from . import jit as jit
from . import fft as fft
#MODIFIED BY TORCHGPIPE
from . import backends
from . import distributed
from . import version
#END
class dtype:
is_floating_point: builtins.bool
......@@ -67,10 +65,8 @@ class device:
type: str
index: _int
#MODIFIED BY TORCHGPIPE
@overload
def __init__(self, device: device) -> None: ...
#END
@overload
def __init__(self, device: Union[_int, str]) -> None: ...
......@@ -78,17 +74,14 @@ class device:
@overload
def __init__(self, type: str, index: _int) -> None: ...
#MODIFIED BY TORCHGPIPE
class Size(tuple):
def numel(self) -> _int: ...
#END
#MODIFIED BY TORCHGPIPE
class Storage:
def size(self) -> _int: ...
def element_size(self) -> _int: ...
def resize_(self, int) -> None: ...
#END
def data_ptr(self) -> _int: ...
# See https://github.com/python/mypy/issues/4146 for why these workarounds
# is necessary
......@@ -935,10 +928,8 @@ class Tensor:
def unique_consecutive(self, sorted=True, return_inverse=False, return_counts=False, dim=None): ...
def lu(self, pivot=True, get_infos=False): ...
#MODIFIED BY TORCHGPIPE
from .cuda import Stream
def record_stream(self, stream: Optional[Stream]) -> None: ...
#END
@overload
def __and__(self: Tensor, other: Number) -> Tensor: ...
......@@ -1924,7 +1915,5 @@ def clear_autocast_cache() -> None: ...
# possible to type correctly
def nonzero(input: Tensor, *, out: Optional[Tensor]=None, as_tuple: Optional[_bool]=None): ...
#MODIFIED BY TORCHGPIPE
def is_grad_enabled() -> _bool: ...
__version__: str = ...
#END
......@@ -32,6 +32,7 @@ class Module(Generic[T_co]):
def add_module(self, name: str, module: 'Module') -> None: ...
def apply(self: T, fn: Callable[['Module'], None]) -> T: ...
def _apply(self: T, fn: Callable[['Module'], None]) -> T: ...
def cuda(self: T, device: Optional[Union[int, str, device]] = ...) -> T: ...
......
......@@ -16,6 +16,8 @@ import numpy as np
import pytest
import torch
pytestmark = pytest.mark.skip(reason="ssd offload to be removed to simplify the code")
try:
import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
......
......@@ -658,6 +658,25 @@ class TestModuleProperties(DistributedTest):
torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].cpu().shape)
class TestResetParameters(DistributedTest):
def test_reset_parameters(self):
"""Ensure that reduce_scatter_process_group same size with the world size."""
test_fn = functools.partial(self._test_reset, config={})
spawn_and_init(test_fn, world_sizes=[2])
@classmethod
def _test_reset(self, rank, group, config):
model = self._get_model(group, config)
with model.summon_full_params():
model.reset_parameters()
@classmethod
def _get_model(self, group, config):
with torch.no_grad(): # required for multiprocessing
model = nn.Linear(10, 10)
return FullyShardedDataParallel(model, group, allow_reset_parameters=True, **config)
class TransformerWithSharedParams(nn.Module):
def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
super().__init__()
......
......@@ -189,10 +189,10 @@ class TestGradAccCommunication(DistributedTest):
# the sum of the _base and public methods should stay the same.
assert (
mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather1
), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather1}"
), f"{mock_all_gather.call_count} + {mock_all_gather_base.call_count} != {expected_all_gather1}"
assert (
mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count == 0
), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != 0"
), f"{mock_reduce_scatter.call_count} + {mock_reduce_scatter_base.call_count} != 0"
output = model(*batch)
loss = model.module.get_loss(batch, output)
......@@ -200,11 +200,11 @@ class TestGradAccCommunication(DistributedTest):
assert (
mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather2
), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather2}"
), f"{mock_all_gather.call_count} + {mock_all_gather_base.call_count} != {expected_all_gather2}"
assert (
mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count
== expected_reduce_scatter
), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != {expected_reduce_scatter}"
), f"{mock_reduce_scatter.call_count} + {mock_reduce_scatter_base.call_count} != {expected_reduce_scatter}"
if __name__ == "__main__":
......
......@@ -16,6 +16,8 @@ import torch
from torch import nn
import torch.distributed
pytestmark = pytest.mark.skip(reason="ssd offload to be removed to simplify the code")
try:
import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
......
......@@ -2,6 +2,7 @@
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import copy
import functools
from time import time
import unittest
......@@ -13,7 +14,7 @@ from torch.optim import SGD, Adadelta, Adam # type: ignore
from fair_dev.testing.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes
from fairscale.internal.params import recursive_copy_to_device
from fairscale.nn import FullyShardedDataParallel
from fairscale.nn.data_parallel import FullyShardedDataParallel, get_fsdp_instances
from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
from .test_fsdp import (
......@@ -158,9 +159,9 @@ class TestOptimizerUtils(DistributedTest):
unwrapped_sd = optim_unwrapped.state_dict()
if not transformer and not expert_group:
no_broadcast_children = [x for x in fsdp._fsdp_instances() if x.no_broadcast_optim_state]
no_broadcast_children = [x for x in get_fsdp_instances(fsdp) if x.no_broadcast_optim_state]
assert len(no_broadcast_children) == 1, f"Length of non shared params {len(no_broadcast_children)}"
assert fsdp._fsdp_instances()[-1].no_broadcast_optim_state
assert get_fsdp_instances(fsdp)[-1].no_broadcast_optim_state
torch.cuda.empty_cache()
cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024**3
tstart = time()
......@@ -196,12 +197,15 @@ class TestOptimizerUtils(DistributedTest):
)
return
unflat_state = sd["state"]
assert "uncollected_local_ids" in sd
shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
sd_copy = copy.deepcopy(sd)
unflat_state = sd_copy["state"]
shard_sd = fsdp.get_shard_from_optim_state_dict(sd_copy)
shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
state_after_get_shard = sd["state"]
assert objects_are_equal(unflat_state, state_after_get_shard) # no side effects.
state_after_get_shard = sd_copy["state"]
# sd is changed in-place in case there are extra states.
assert not objects_are_equal(unflat_state, state_after_get_shard)
del sd_copy
assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
......@@ -223,8 +227,8 @@ class TestOptimizerUtils(DistributedTest):
[v for k, v in shard_sd["param_groups"][0].items()],
[v for k, v in original_shard_sd["param_groups"][0].items()],
)
assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)
objects_are_equal(shard_sd["state"], original_shard_sd["state"], raise_exception=True)
objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd, raise_exception=True)
@parameterized.expand(
[(True,), (False,)],
......@@ -260,7 +264,7 @@ class TestOptimizerUtils(DistributedTest):
model = TransformerWithSharedParams(group)
named_pars = [p for n, p in model.named_parameters()]
for i, p in enumerate(model.parameters()):
assert objects_are_equal(p, named_pars[i])
objects_are_equal(p, named_pars[i], raise_exception=True)
def test_is_singleton_tensor(self):
"""Test is_singleton_tensor function"""
......
......@@ -158,7 +158,8 @@ def _dist_worker(rank, world_size, files, wrap_middle, test_fn):
# We don't raise exceptions in CI since CI's T4 machine seems to be flaky with this test.
# On devel machines, we do want to catch potential errors. There could be real bugs or
# system issues behind the flakiness. One example is all-reduce vs. simulated averaging
# below.
# below. The check also fails on my rtx 20xx. So maybe it only works on devfair with
# Quadro GP100 GPUs. TODO (Min): debug this.
objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=not in_circle_ci())
elif test_fn == "eval":
_eval(fsdp_model, in_data)
......