Unverified commit 429f3d31 authored by Min Xu, committed by GitHub

[fix] better handling non-flatten in FSDP (#1072)



* [fix] better handling non-flatten in FSDP

- see the detailed comment about that backward firing case
- also minor debugging help in FSDP
- also minor fix in FPW's state dict

* [feat] disallow reset_parameters by default

* [feat] adding the get_fsdp_instances API - useful for user code to check wrapping

* [fix] one line fix but more than a day of debugging

* fixed the case of loading a combined checkpoint with empty FSDP instances

* fixed another state-loading bug where root/non-root modules cached full params due to not resharding after forward

* [feat] support .half and .float better

* fixed a bug where gathering optim state loses extra keys from the original state_dict

* fixed a test failure in mixed precision

* fixed another bug affecting no_sync grad acc

* fixed a bug and a test in fsdp optim state

* fixed another corner case

* added a comment

* skip ssd offload tests

* skip the fsdp one for ssd offload
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 47ce21ac
...@@ -500,7 +500,7 @@ class LayerwiseMemoryTracker: ...@@ -500,7 +500,7 @@ class LayerwiseMemoryTracker:
Indicate if x and y share the same storage, meaning that one of them Indicate if x and y share the same storage, meaning that one of them
is a view, reshape or stride of the other or from a common tensor is a view, reshape or stride of the other or from a common tensor
""" """
return x.storage().data_ptr() == y.storage().data_ptr() # type: ignore return x.storage().data_ptr() == y.storage().data_ptr()
@staticmethod @staticmethod
def _collect_tensors(module_io_tensors: Union[torch.Tensor, Sequence[torch.Tensor]]) -> List[torch.Tensor]: def _collect_tensors(module_io_tensors: Union[torch.Tensor, Sequence[torch.Tensor]]) -> List[torch.Tensor]:
......
...@@ -12,6 +12,7 @@ from .fully_sharded_data_parallel import ( ...@@ -12,6 +12,7 @@ from .fully_sharded_data_parallel import (
OffloadConfig, OffloadConfig,
TrainingState, TrainingState,
auto_wrap_bn, auto_wrap_bn,
get_fsdp_instances,
no_pre_load_state_dict_hook, no_pre_load_state_dict_hook,
) )
......
...@@ -14,9 +14,6 @@ from fairscale.nn.misc import FlattenParamsWrapper ...@@ -14,9 +14,6 @@ from fairscale.nn.misc import FlattenParamsWrapper
if TYPE_CHECKING: if TYPE_CHECKING:
from fairscale.nn.data_parallel import FullyShardedDataParallel from fairscale.nn.data_parallel import FullyShardedDataParallel
# These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}
# This function helps shard a full optimizer state dict # This function helps shard a full optimizer state dict
def flatten_optim_state_dict(sd: Dict) -> Dict: def flatten_optim_state_dict(sd: Dict) -> Dict:
"""Shard a full optimizer state dict (called by FSDP.get_shard_from_optim_state_dict)""" """Shard a full optimizer state dict (called by FSDP.get_shard_from_optim_state_dict)"""
...@@ -52,20 +49,24 @@ def flatten_optim_state_dict(sd: Dict) -> Dict: ...@@ -52,20 +49,24 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
new_state[local_id][buffer_name] = torch.cat(tensors) new_state[local_id][buffer_name] = torch.cat(tensors)
new_state[local_id].update(non_tensor_state) new_state[local_id].update(non_tensor_state)
new_state[local_id].update(singleton_state[local_id]) new_state[local_id].update(singleton_state[local_id])
new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
for k in sd.keys(): # if there are extra keys, like loss_scale, don't delete them
if k not in UNFLAT_RETURN_KEYS:
new_sd[k] = copy.deepcopy(sd[k])
# Now make a new param_groups copy and update it.
new_sd_pg = copy.deepcopy(sd["param_groups"])
# add pointers from the `params` dict. # add pointers from the `params` dict.
for pg_id, _ in enumerate(sd["param_groups"]): for pg_id, _ in enumerate(sd["param_groups"]):
# The values() list may look like [0,0,None,None,2,2]. We use # The values() list may look like [0,0,None,None,2,2]. We use
# groupby to remove the duplicates and then count the length of # groupby to remove the duplicates and then count the length of
# resulting iter. # resulting iter.
num_local_params = sum(1 for _ in groupby(param_id_map.values())) num_local_params = sum(1 for _ in groupby(param_id_map.values()))
new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params)) new_sd_pg[pg_id]["params"] = list(range(num_local_params))
return new_sd # update the original sd so that we don't lose extra keys, like loss_scale.
sd["state"] = new_state
sd["param_groups"] = new_sd_pg
# delete extra keys we have added to match the original state.
del sd["uncollected_local_ids"]
del sd["param_id_map"]
return sd
def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None: def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None:
...@@ -202,7 +203,7 @@ def build_unflat_state_dict( ...@@ -202,7 +203,7 @@ def build_unflat_state_dict(
state: Dict[int, Dict[str, List[torch.Tensor]]], state: Dict[int, Dict[str, List[torch.Tensor]]],
singleton_state: Dict[int, Dict[str, List[torch.Tensor]]], singleton_state: Dict[int, Dict[str, List[torch.Tensor]]],
uncollected_opt_state: Dict[int, Dict], uncollected_opt_state: Dict[int, Dict],
param_groups: List[Dict], original_sd: Dict,
) -> Dict: ) -> Dict:
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts """Build an unflattened optimizer state dict given a list of flattened optimizer state dicts
from each rank. This is only called on rank 0. from each rank. This is only called on rank 0.
...@@ -213,7 +214,7 @@ def build_unflat_state_dict( ...@@ -213,7 +214,7 @@ def build_unflat_state_dict(
state: all-gathered combined/local/flatten state_dict state: all-gathered combined/local/flatten state_dict
singleton_state: all-gathered singleton_state (dimensionless tensors) singleton_state: all-gathered singleton_state (dimensionless tensors)
uncollected_opt_state: non-tensor and not-gathered state uncollected_opt_state: non-tensor and not-gathered state
param_groups: the original rank 0's sd["param_groups"] original_sd: the original rank 0's sd
Returns: Returns:
dict: an unflattened, nonsharded optimizer state, as if FSDP was not there. dict: an unflattened, nonsharded optimizer state, as if FSDP was not there.
...@@ -228,19 +229,19 @@ def build_unflat_state_dict( ...@@ -228,19 +229,19 @@ def build_unflat_state_dict(
singleton_state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if is_singleton_tensor(x)} singleton_state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if is_singleton_tensor(x)}
# local ids are in the current state, global_ids will be in returned state. # local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(state, instance_list, world_pad_info, singleton_state) unflat_state, global_to_local_id = _unflatten_optim_state(state, instance_list, world_pad_info, singleton_state)
# Since there are no tensors in param_groups, deepcopy is fine. # Since there are no tensors in param_groups, deepcopy is fine.
param_groups = copy.deepcopy(param_groups) param_groups = copy.deepcopy(original_sd["param_groups"])
# Casting needed only for mypy. # Casting needed only for mypy.
num_params = sum([cast(int, m.num_params_managed) for m in instance_list]) num_params = sum([cast(int, m.num_params_managed) for m in instance_list])
param_groups[0]["params"] = list(range(num_params)) param_groups[0]["params"] = list(range(num_params))
unflat_optim_state_dict = {
"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted # Update the original sd so we don't loss extra state like loss_scale.
"param_id_map": global_to_local_id, original_sd["state"] = dict(sorted(unflat_state.items())) # NOTE: this is probably already sorted
"param_groups": param_groups, original_sd["param_id_map"] = global_to_local_id
"uncollected_local_ids": list(uncollected_opt_state.keys()), original_sd["param_groups"] = param_groups
} original_sd["uncollected_local_ids"] = list(uncollected_opt_state.keys())
assert set(unflat_optim_state_dict.keys()) == UNFLAT_RETURN_KEYS return original_sd
return unflat_optim_state_dict
def is_singleton_tensor(x: Any) -> bool: def is_singleton_tensor(x: Any) -> bool:
......
...@@ -217,6 +217,8 @@ class FullyShardedDataParallel(nn.Module): ...@@ -217,6 +217,8 @@ class FullyShardedDataParallel(nn.Module):
save memory. Consider a case that an FSDP root module is a submodule of a model. save memory. Consider a case that an FSDP root module is a submodule of a model.
Backward pass may not start immediate after the FSDP root module finishes its forward. Backward pass may not start immediate after the FSDP root module finishes its forward.
So, reshard the parameters for the FSDP root modules can help to save memory in this case. So, reshard the parameters for the FSDP root modules can help to save memory in this case.
In certain cases, this is not even slower, because the cached full param
state may be stale anyway due to load_local_state_dict() calls.
Default: True. Default: True.
mixed_precision (bool, Optional): mixed_precision (bool, Optional):
if ``True``, inputs, activations and gradients will be kept in FP16; if ``True``, inputs, activations and gradients will be kept in FP16;
...@@ -230,6 +232,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -230,6 +232,7 @@ class FullyShardedDataParallel(nn.Module):
which improves training speed. which improves training speed.
move_params_to_cpu (bool, Optional): move_params_to_cpu (bool, Optional):
if ``True``, offload params to CPU. if ``True``, offload params to CPU.
Default: False
compute_dtype (torch.dtype, Optional): compute_dtype (torch.dtype, Optional):
dtype for full parameters for computation. This defaults to dtype for full parameters for computation. This defaults to
``torch.float32`` unless *``mixed_precision``* is set, in which case ``torch.float32`` unless *``mixed_precision``* is set, in which case
...@@ -308,6 +311,12 @@ class FullyShardedDataParallel(nn.Module): ...@@ -308,6 +311,12 @@ class FullyShardedDataParallel(nn.Module):
rank 0 and return empty dict non-rank 0, which allow FullyShardedDataParallel to rank 0 and return empty dict non-rank 0, which allow FullyShardedDataParallel to
skip the GPU -> CPU copy on non-rank 0 altogether and prevent OOM. skip the GPU -> CPU copy on non-rank 0 altogether and prevent OOM.
Default: False Default: False
gradient_predivide_factor (float, Optional):
If supplied, pre-divide the gradients by this factor before the reduce-scatter.
Default: None
allow_reset_parameters (bool):
If ``True``, allow the ``reset_parameters`` API call to be proxied to the wrapped module.
Default: False
""" """
def __init__( def __init__(
...@@ -336,6 +345,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -336,6 +345,7 @@ class FullyShardedDataParallel(nn.Module):
offload_config: Optional[OffloadConfig] = None, offload_config: Optional[OffloadConfig] = None,
state_dict_on_rank_0_only: bool = False, state_dict_on_rank_0_only: bool = False,
gradient_predivide_factor: Optional[float] = None, gradient_predivide_factor: Optional[float] = None,
allow_reset_parameters: bool = False,
): ):
try: try:
import torch._C import torch._C
...@@ -415,6 +425,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -415,6 +425,7 @@ class FullyShardedDataParallel(nn.Module):
self.world_size self.world_size
) )
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self.allow_reset_parameters = allow_reset_parameters
self.numel_padded_per_param: List[int] = [] self.numel_padded_per_param: List[int] = []
self._tstart = time.time() self._tstart = time.time()
...@@ -488,6 +499,10 @@ class FullyShardedDataParallel(nn.Module): ...@@ -488,6 +499,10 @@ class FullyShardedDataParallel(nn.Module):
for p in self.params: for p in self.params:
if p.dtype is not torch.float16: if p.dtype is not torch.float16:
raise ValueError("Expecting FP16 param type in pure FP16 mode.") raise ValueError("Expecting FP16 param type in pure FP16 mode.")
else:
for p in self.params:
if p.dtype is not torch.float32:
raise ValueError("Expecting FP16 param type in FP32 & MP modes.")
# Shard module parameters in place # Shard module parameters in place
self._shard_parameters_() self._shard_parameters_()
...@@ -531,9 +546,8 @@ class FullyShardedDataParallel(nn.Module): ...@@ -531,9 +546,8 @@ class FullyShardedDataParallel(nn.Module):
# Free all params at the end of initialization. # Free all params at the end of initialization.
if self.ssd_offload: if self.ssd_offload:
for m in self.modules(): # includes self for m in get_fsdp_instances(self):
if isinstance(m, FullyShardedDataParallel): m._free_ssd_offload()
m._free_ssd_offload()
def _get_gradient_predivide_factor(self, world_size: int) -> float: def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1 factor: int = 1
...@@ -820,12 +834,12 @@ class FullyShardedDataParallel(nn.Module): ...@@ -820,12 +834,12 @@ class FullyShardedDataParallel(nn.Module):
f"compute_dtype={self.compute_dtype}, " f"compute_dtype={self.compute_dtype}, "
f"buffer_dtype={self.buffer_dtype}, " f"buffer_dtype={self.buffer_dtype}, "
f"fp32_reduce_scatter={self.fp32_reduce_scatter}, " f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
f"compute_device={self.compute_device}" f"compute_device={self.compute_device}, "
f"move_params_to_cpu={self.move_params_to_cpu}, " f"move_params_to_cpu={self.move_params_to_cpu}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}, " f"move_grads_to_cpu={self.move_grads_to_cpu}, "
f"bucket_cap_mb={self.bucket_cap_mb}, " f"bucket_cap_mb={self.bucket_cap_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}" f"clear_autocast_cache={self.clear_autocast_cache}, "
f"force_input_to_fp32={self.force_input_to_fp32}" f"force_input_to_fp32={self.force_input_to_fp32}, "
) )
return repr return repr
...@@ -989,9 +1003,8 @@ class FullyShardedDataParallel(nn.Module): ...@@ -989,9 +1003,8 @@ class FullyShardedDataParallel(nn.Module):
) )
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack:
# Tell any nested FSDP instances not to auto summon full params. # Tell any nested FSDP instances not to auto summon full params.
for module in self.modules(): # includes self for module in get_fsdp_instances(self):
if isinstance(module, FullyShardedDataParallel): stack.enter_context(module._no_return_full_state_dict())
stack.enter_context(module._no_return_full_state_dict())
# We need to specially call FSDP's state_dict function in case # We need to specially call FSDP's state_dict function in case
# self.state_dict is a function from a child class of FSDP. # self.state_dict is a function from a child class of FSDP.
return FullyShardedDataParallel.state_dict(self, *args, **kwargs) return FullyShardedDataParallel.state_dict(self, *args, **kwargs)
...@@ -1057,10 +1070,22 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1057,10 +1070,22 @@ class FullyShardedDataParallel(nn.Module):
) )
with contextlib.ExitStack() as stack: with contextlib.ExitStack() as stack:
# Tell any nested FSDP instances not to auto summon full params. # Tell any nested FSDP instances not to auto summon full params.
for module in self.modules(): # includes self for module in get_fsdp_instances(self):
if isinstance(module, FullyShardedDataParallel): stack.enter_context(module._no_return_full_state_dict())
stack.enter_context(module._no_return_full_state_dict())
output = self._load_state_dict(state_dict, strict) output = self._load_state_dict(state_dict, strict)
# After loading the local state, if an FSDP wrapper has the full
# params built, it will not use the updated values. Therefore we call
# _free_full_params() here to force them to be rebuilt the next time
# they are needed.
#
# There are 2 cases where this can happen:
# 1. in training, the outermost wrapper may have reshard_after_forward set
# to False (_is_root is True); therefore, the full param is built and kept.
# 2. in eval, inner modules may get called directly, hence there can be
# multiple "root" instances; therefore, we need to loop over all instances
# below to free full params.
for module in get_fsdp_instances(self):
module._free_full_params()
return output return output
@contextlib.contextmanager @contextlib.contextmanager
...@@ -1077,7 +1102,8 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1077,7 +1102,8 @@ class FullyShardedDataParallel(nn.Module):
.. note:: Gradient accumulation can be done without this context, .. note:: Gradient accumulation can be done without this context,
avoiding the extra GPU memory overhead, but with the extra avoiding the extra GPU memory overhead, but with the extra
networking overhead. networking overhead. I.e. the trainer loop can simply run
multiple fwd/bwd passes before calling step(), without the no_sync context.
""" """
self._lazy_init() self._lazy_init()
assert self._is_root, "no_sync on inner FSDP is not supported" assert self._is_root, "no_sync on inner FSDP is not supported"
...@@ -1085,10 +1111,9 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1085,10 +1111,9 @@ class FullyShardedDataParallel(nn.Module):
# This instance may wrap other FSDP instances and we # This instance may wrap other FSDP instances and we
# need to set all of them to accumulate gradients. # need to set all of them to accumulate gradients.
old_flags = [] old_flags = []
for m in self.modules(): # includes self for m in get_fsdp_instances(self):
if isinstance(m, FullyShardedDataParallel): old_flags.append((m, m._require_backward_grad_sync))
old_flags.append((m, m._require_backward_grad_sync)) m._require_backward_grad_sync = False
m._require_backward_grad_sync = False
try: try:
yield yield
finally: finally:
...@@ -1096,6 +1121,89 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1096,6 +1121,89 @@ class FullyShardedDataParallel(nn.Module):
assert m._require_backward_grad_sync is False assert m._require_backward_grad_sync is False
m._require_backward_grad_sync = old_flag m._require_backward_grad_sync = old_flag
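As an aside to the note in ``no_sync``'s docstring above, here is a minimal sketch (not part of this diff) of no_sync-free gradient accumulation with an FSDP-wrapped model; ``my_module`` and ``loader`` are hypothetical placeholders and a process group is assumed to be initialized already.

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel

# Assumptions: torch.distributed.init_process_group() has already been called,
# and `my_module` / `loader` stand in for a real model and data loader.
fsdp_model = FullyShardedDataParallel(my_module).cuda()
optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.01)
accum_steps = 4

optim.zero_grad()
for step, batch in enumerate(loader):
    loss = fsdp_model(batch).sum()
    # Gradients are reduce-scattered on every backward and accumulate in the
    # sharded p.grad, so no full-size gradients are kept (unlike under no_sync).
    (loss / accum_steps).backward()
    if (step + 1) % accum_steps == 0:
        optim.step()
        optim.zero_grad()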
def reset_parameters(self) -> None:
"""Special reset_parameters API handling.
We don't by default allow this API because it has at least 2 issues:
1. calling it after wrapping can crash due to unexpected tensor size
and dimensions due to flattening and sharding. So summon_full_params
context might be required.
2. calling it after wrapping can result in incorrect init values due
to flattening.
See this gist for an example of the init issue when parameters are
flatten.
https://gist.github.com/407bb158f0d0612e157c2cbcf5c8b76a
Or, like in 1, init function can silently init the weight differently
because of the dimensions.
Finally, be advised that init on CPU vs. on GPU can have different
values. If models are originally on CPU and after wrapping it is moved
to GPU, calling this will again be problematic.
"""
if self.allow_reset_parameters:
self.module.reset_parameters()
else:
raise RuntimeError("reset parameters after FSDP wrapping is not allowed")
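For illustration only, a minimal sketch of how the new ``allow_reset_parameters`` flag is meant to be used (mirroring the unit test added in this PR); it assumes a process group is already initialized.

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel

# Assumption: torch.distributed is already initialized.
fsdp = FullyShardedDataParallel(nn.Linear(10, 10), allow_reset_parameters=True)
with fsdp.summon_full_params():  # restore full, unflattened params first
    fsdp.reset_parameters()      # proxied to nn.Linear.reset_parameters()
# With the default allow_reset_parameters=False, the same call raises RuntimeError.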
def _apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
"""Hook into model conversion, like .half() and .float()
When users call module.half() or module.float() after FSDP wrapping,
we need to update some internal states here.
Args:
fn (Callable):
same as nn.Module's _apply.
Returns:
(Any):
same as nn.Module's _apply.
"""
# Just a precaution. Conversion while the training state is IDLE is the safest.
self.assert_state(TrainingState.IDLE)
# In order to determine how to change compute_dtype, we need to
# remember the dtype before this call.
if len(self.params):
dtype_before = self.params[0].dtype
# Call nn.Module's _apply.
ret = super()._apply(fn)
# Make sure we update p._full_param_padded according to the new dtype if we are
# not in mixed_precision. In mixed precision, doing m.half() or m.float() really
# doesn't make much sense. But we need to allow it in case the user just wants to
# temporarily .half() and then .float() back for some reason.
if not self.mixed_precision:
for p in self.params:
if hasattr(p, "_full_param_padded"):
allocated = False
if p._full_param_padded.storage().size() == 0:
allocated = True
alloc_storage_(p._full_param_padded, size=p._full_param_padded.size())
p._full_param_padded = p._full_param_padded.to(dtype=p.data.dtype, device=p.data.device)
if allocated:
free_storage_(p._full_param_padded)
# Update compute_dtype because otherwise, p._full_param_padded will
# still be in that dtype.
if len(self.params):
dtype_after = self.params[0].dtype
if dtype_before != dtype_after:
# There are 4 cases below. Only 2 result in compute_dtype change
# to the dtype_after.
# 16 -> 32, 32 -> 16
# mixed n/a no change
# not mixed change change
if not self.mixed_precision:
self.compute_dtype = dtype_after
return ret
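A rough sketch of the conversion flow this ``_apply`` override is meant to support when ``mixed_precision`` is off; ``module`` and ``inputs`` are hypothetical placeholders and a process group is assumed to be initialized.

# Sketch only, not part of this diff.
fsdp = FullyShardedDataParallel(module, mixed_precision=False)

fsdp.half()          # local shards and _full_param_padded are cast to fp16,
                     # and compute_dtype is switched to torch.float16
out = fsdp(inputs)   # forward/backward now run in fp16
fsdp.float()         # cast back; compute_dtype returns to torch.float32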
@contextlib.contextmanager @contextlib.contextmanager
def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator: def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
""" """
...@@ -1138,8 +1246,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1138,8 +1246,7 @@ class FullyShardedDataParallel(nn.Module):
torch.cuda.synchronize() torch.cuda.synchronize()
self._lazy_init() self._lazy_init()
self.assert_state(TrainingState.IDLE) self.assert_state(TrainingState.IDLE)
# Set the state so that we assert when trying to go into # Set the state so that we assert when trying to go into fwd/bwd.
# forward/backward.
self.training_state = TrainingState.SUMMON_FULL_PARAMS self.training_state = TrainingState.SUMMON_FULL_PARAMS
full_tensors = self._rebuild_full_params(force_full_precision=True) full_tensors = self._rebuild_full_params(force_full_precision=True)
assert full_tensors is not None assert full_tensors is not None
...@@ -1267,9 +1374,9 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1267,9 +1374,9 @@ class FullyShardedDataParallel(nn.Module):
p._fp32_shard = p.data p._fp32_shard = p.data
if self.mixed_precision: if self.mixed_precision:
assert p._fp32_shard.dtype == torch.float32 assert p._fp32_shard.dtype == torch.float32, self
if self.move_params_to_cpu: if self.move_params_to_cpu:
assert p._fp32_shard.device == torch.device("cpu") assert p._fp32_shard.device == torch.device("cpu"), self
# We don't pin memory when using ssd_offload since that results in OOM when # We don't pin memory when using ssd_offload since that results in OOM when
# the memory requirements of a model are larger than host memory. # the memory requirements of a model are larger than host memory.
...@@ -1817,6 +1924,29 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1817,6 +1924,29 @@ class FullyShardedDataParallel(nn.Module):
def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None: def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
"""Helper used below on all fsdp modules.""" """Helper used below on all fsdp modules."""
if not fsdp_module._is_root and self._require_backward_grad_sync:
# We make sure to switch to fp32 shards here because there might be
# params lingering in full_param mode, if the following firing order happens:
# pre-bwd: rebuild and use full for p1 and p2
# post-bwd for p1: free and switch to fp32 shard for p1
# pre-bwd: rebuild again for p1 and p2
# post-bwd for p2: free and switch to fp32 shard for p2
# In the end, p1 will be left in full param mode.
#
# We need to switch to fp32 *and* free full params. If we don't free,
# we end up reusing a potentially *stale* full param after the fp32
# shard is updated (e.g. by optimizer.step()).
#
# We skip the root because it may have reshard=False, which means
# we want to keep the speed benefit of that. I haven't seen a case
# where this is needed on the root module.
#
# We also skip this in grad accum steps since we want to keep the full
# params, which haven't been updated yet. See the comment of ``no_sync``
# for when to use no_sync-style grad accumulation. For FSDP, it is more
# likely you want to use grad accumulation without no_sync.
fsdp_module._free_full_params()
fsdp_module._use_fp32_param_shard()
for p in fsdp_module.params: for p in fsdp_module.params:
if not p.requires_grad: if not p.requires_grad:
continue continue
...@@ -1828,7 +1958,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1828,7 +1958,7 @@ class FullyShardedDataParallel(nn.Module):
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard # remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
# remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
# sync passes, if desired. # sync passes.
if not self._require_backward_grad_sync: if not self._require_backward_grad_sync:
continue continue
...@@ -1841,46 +1971,49 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1841,46 +1971,49 @@ class FullyShardedDataParallel(nn.Module):
p.device == p._saved_grad_shard.device, p.device == p._saved_grad_shard.device,
f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}", f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
) )
p_assert(
p.shape == p._saved_grad_shard.shape,
f"WFPB: incorrect saved_grad_shard shape {p.shape} vs {p._saved_grad_shard.shape}",
)
p.grad = p._saved_grad_shard p.grad = p._saved_grad_shard
if hasattr(p, "_saved_grad_shard"): if hasattr(p, "_saved_grad_shard"):
delattr(p, "_saved_grad_shard") delattr(p, "_saved_grad_shard")
# Update root and nested FSDP's hooks and flags. # Update root and nested FSDP's hooks and flags.
for m in self.modules(): # includes self for m in get_fsdp_instances(self):
if isinstance(m, FullyShardedDataParallel): _finalize_parameters(m)
_finalize_parameters(m) m._free_ssd_offload()
m._free_ssd_offload() m._pre_backward_hook_has_run = False
m._pre_backward_hook_has_run = False if any(p.requires_grad for p in m.parameters()):
if any(p.requires_grad for p in m.parameters()): # Check if the module has params and if any of them has
# Check if the module has params and if any of them has # the `requires_grad` field set. If `requires_grad=False` for
# the `requires_grad` field set. If `requires_grad=False` for # all the params, the post_backward hook will not fire and the
# all the params, the post_backward hook will not fire and the # state will remain in `TrainingState.BACKWARD_PRE`.
# state will remain in `TrainingState.BACKWARD_PRE`. if any([p.requires_grad for p in m.params]):
if any([p.requires_grad for p in m.params]): m.assert_state(TrainingState.BACKWARD_POST)
m.assert_state(TrainingState.BACKWARD_POST)
else:
m.assert_state(TrainingState.BACKWARD_PRE)
else: else:
# When `m` and its children has no params or has params but m.assert_state(TrainingState.BACKWARD_PRE)
# none with `requires_grad==True`, there are two cases: else:
# 1. output tensors are `requires_grad==True`. In this case, # When `m` and its children has no params or has params but
# pre-backward hook is still registered, so it is in BACKWARD_PRE state. # none with `requires_grad==True`, there are two cases:
# 2. output tensors are `requires_grad==False`. In this case, # 1. output tensors are `requires_grad==True`. In this case,
# pre-backward hook is not registered, so it is in IDLE state. # pre-backward hook is still registered, so it is in BACKWARD_PRE state.
m.assert_state([TrainingState.BACKWARD_PRE, TrainingState.IDLE]) # 2. output tensors are `requires_grad==False`. In this case,
m.training_state = TrainingState.IDLE # pre-backward hook is not registered, so it is in IDLE state.
m.assert_state([TrainingState.BACKWARD_PRE, TrainingState.IDLE])
if m._is_root: m.training_state = TrainingState.IDLE
# reset this flag for cases like "one forward pass + multiple backward passes"
self._post_backward_callback_queued = False if m._is_root:
# clear this list for next iteration # reset this flag for cases like "one forward pass + multiple backward passes"
p_assert( self._post_backward_callback_queued = False
self._output_pre_backward_hook_registered is not None, # clear this list for next iteration
"WFPB: self._output_pre_backward_hook_registered should not be None", p_assert(
) self._output_pre_backward_hook_registered is not None,
assert self._output_pre_backward_hook_registered is not None # make mypy happy "WFPB: self._output_pre_backward_hook_registered should not be None",
self._output_pre_backward_hook_registered.clear() )
assert self._output_pre_backward_hook_registered is not None # make mypy happy
self._output_pre_backward_hook_registered.clear()
@torch.no_grad() @torch.no_grad()
def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]: def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
...@@ -1979,25 +2112,38 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1979,25 +2112,38 @@ class FullyShardedDataParallel(nn.Module):
if not p._is_sharded: # e.g., when world_size == 1 if not p._is_sharded: # e.g., when world_size == 1
update_p_data() update_p_data()
else: else:
# Skip if already built. Only shared param can be rebuilt multiple times. # Skip if already built.
#
# case 1: shared param can be rebuilt multiple times.
# A corner case is p._orig_size = (1,), which means the shape equality is # A corner case is p._orig_size = (1,), which means the shape equality is
# not a perfect check. But we assume we don't share a param with shape (1,). # not a perfect check. But we assume we don't share a param with shape (1,).
if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared: # We do use size (1,) in unit testing at least.
# case 2: with multiple params (like non-flatten, or multiple flatten groups),
# we may have pre- & post-bwd firing order issues. See the comments in the
# _finalize_parameters function for such a case.
if p.data.shape == p._orig_size and p._orig_size != (1,):
assert p.data.storage().data_ptr() == p._full_param_padded.storage().data_ptr(), (
f"p.data {p.data.storage().data_ptr()} "
f"p._fp32_shard {p._fp32_shard.storage().data_ptr()} "
f"p._fp16_shard {p._fp16_shard.storage().data_ptr() if p._fp16_shard is not None else None} "
f"p._full_params_padded {p._full_param_padded.storage().data_ptr()} "
)
continue continue
# If self.move_params_to_cpu and force_full_precision, we need to cast # If self.move_params_to_cpu and force_full_precision, we need to cast
# the FP32 CPU param to CUDA for the all-gather. # the FP32 CPU param to CUDA for the all-gather.
p_data = p.data.to(p._full_param_padded.device, non_blocking=True) p_data = p.data.to(p._full_param_padded.device, non_blocking=True)
p_size = p._full_param_padded.size() full_p_size = p._full_param_padded.size()
assert p_size.numel() % self.world_size == 0 assert full_p_size.numel() % self.world_size == 0
if self.mixed_precision and force_full_precision: if self.mixed_precision and force_full_precision:
# Allocate fresh tensor in full precision since we are in # Allocate fresh tensor in full precision since we are in
# mixed precision and full precision rebuild is asked. # mixed precision and full precision rebuild is asked.
output_tensor = p_data.new_zeros(p_size) output_tensor = p_data.new_zeros(full_p_size)
else: else:
if p._full_param_padded.storage().size() != p_size.numel(): if p._full_param_padded.storage().size() != full_p_size.numel():
# Allocate based on full size from all shards. # Allocate based on full size from all shards.
alloc_storage_(p._full_param_padded, size=p_size) alloc_storage_(p._full_param_padded, size=full_p_size)
output_tensor = p._full_param_padded output_tensor = p._full_param_padded
# Fill output_tensor with (p.data for each shard in self.world_size) # Fill output_tensor with (p.data for each shard in self.world_size)
...@@ -2284,10 +2430,10 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2284,10 +2430,10 @@ class FullyShardedDataParallel(nn.Module):
raise ValueError(msg) raise ValueError(msg)
def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]: def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
"""Collect [x.numel_padded_per_param for x in self._fsdp_instances()] from each rank.""" """Collect [x.numel_padded_per_param for x in get_fsdp_instances(self)] from each rank."""
world_pad_info: List[List[List[int]]] = [] # this will contain values from the whole world. world_pad_info: List[List[List[int]]] = [] # this will contain values from the whole world.
my_pad_info: List[List[int]] = [ my_pad_info: List[List[int]] = [
cast(List[int], m.numel_padded_per_param) for m in self._fsdp_instances(skip_empty=True) cast(List[int], m.numel_padded_per_param) for m in get_fsdp_instances(self, skip_empty=True)
] ]
for rank in range(self.world_size): for rank in range(self.world_size):
if rank == self.rank: if rank == self.rank:
...@@ -2307,7 +2453,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2307,7 +2453,7 @@ class FullyShardedDataParallel(nn.Module):
singleton_state: Dict[int, Dict[str, List[Any]]] = {} # Dimensionless tensor singleton_state: Dict[int, Dict[str, List[Any]]] = {} # Dimensionless tensor
# Non-empty FSDP instance and sd_state item number must match. # Non-empty FSDP instance and sd_state item number must match.
fsdp_instances = self._fsdp_instances(skip_empty=True) fsdp_instances = get_fsdp_instances(self, skip_empty=True)
assert len(fsdp_instances) >= len(sd_state), f"{len(fsdp_instances)} vs. {len(sd_state)}" assert len(fsdp_instances) >= len(sd_state), f"{len(fsdp_instances)} vs. {len(sd_state)}"
for k, v in sd_state.items(): for k, v in sd_state.items():
...@@ -2377,7 +2523,9 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2377,7 +2523,9 @@ class FullyShardedDataParallel(nn.Module):
self._lazy_init() self._lazy_init()
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict()) sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
assert set(sd.keys()) == {"param_groups", "state"}, f'{set(sd.keys())} != {"param_groups", "state"}' assert {"param_groups", "state"}.issubset(
set(sd.keys())
), f'{set(sd.keys())} not a superset of {"param_groups", "state"}'
assert len(sd["param_groups"]) == 1, "Param groups are not supported" assert len(sd["param_groups"]) == 1, "Param groups are not supported"
# We use all_gather to consolidate OSD['state'] and broadcast to consolidate the other keys (like param_groups) # We use all_gather to consolidate OSD['state'] and broadcast to consolidate the other keys (like param_groups)
state, singleton_state = self._gather_optim_state(sd.pop("state")) state, singleton_state = self._gather_optim_state(sd.pop("state"))
...@@ -2386,29 +2534,30 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2386,29 +2534,30 @@ class FullyShardedDataParallel(nn.Module):
return None return None
# Unify the shard states by concatenating tensors and unflattening params # Unify the shard states by concatenating tensors and unflattening params
new_state_dict = ou.build_unflat_state_dict( new_state_dict = ou.build_unflat_state_dict(
self._fsdp_instances(skip_empty=True), get_fsdp_instances(self, skip_empty=True),
pad_info, pad_info,
state, state,
singleton_state, singleton_state,
self.uncollected_opt_state, self.uncollected_opt_state,
sd["param_groups"], sd,
) )
self.uncollected_opt_state = {} self.uncollected_opt_state = {}
assert "uncollected_local_ids" in new_state_dict assert "uncollected_local_ids" in new_state_dict
return new_state_dict return new_state_dict
def _fsdp_instances(self, skip_empty: bool = False) -> List["FullyShardedDataParallel"]:
"""Returns all fsdp modules in self.modules() including self."""
result = [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
if skip_empty:
result = list(filter(lambda x: len(cast(FullyShardedDataParallel, x).non_shared_params()) > 0, result))
return result
def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict: def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
"""Return a new state dict filtering out the ones like MoE layers, which has """Return a new state dict filtering out the ones like MoE layers, which has
``no_broadcast_optim_state`` flag set. ``no_broadcast_optim_state`` flag set.
We also make rooms for the optimizer state on rank 0. We also make rooms for the optimizer state on rank 0.
Args:
osd (Dict):
Optimizer state dict from a rank. osd["state"] is what we mainly
care about. osd may contain other keys and values that we need to keep.
Therefore, we only change osd["state"] and do not return a new copy of osd,
which would be slower and might also lose extra fields, like "loss_scale"
used by fairseq.
""" """
# In PyTorch version 1.12, Adam's `step` state changed from an int to a singleton # In PyTorch version 1.12, Adam's `step` state changed from an int to a singleton
# tensor. We convert it back here. Otherwise, the step counter will be treated # tensor. We convert it back here. Otherwise, the step counter will be treated
...@@ -2419,8 +2568,8 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2419,8 +2568,8 @@ class FullyShardedDataParallel(nn.Module):
if ou.is_singleton_tensor(bufs["step"]): if ou.is_singleton_tensor(bufs["step"]):
bufs["step"] = bufs["step"].item() bufs["step"] = bufs["step"].item()
# Get uncollected_ids. # Get uncollected_ids.
uncollected_ids = [i for i, m in enumerate(self._fsdp_instances()) if m.no_broadcast_optim_state] uncollected_ids = [i for i, m in enumerate(get_fsdp_instances(self)) if m.no_broadcast_optim_state]
new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}} new_state_value = {k: v for k, v in osd["state"].items() if k not in uncollected_ids}
if self.rank == 0: if self.rank == 0:
# Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU. # Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
self.uncollected_opt_state = { self.uncollected_opt_state = {
...@@ -2429,9 +2578,8 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2429,9 +2578,8 @@ class FullyShardedDataParallel(nn.Module):
if k in uncollected_ids if k in uncollected_ids
} }
pg = copy.deepcopy(osd["param_groups"]) osd["state"] = new_state_value
new_dct["param_groups"] = pg return osd
return new_dct
def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]: def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Get the portion of the optimizer state dict associated with the shard """Get the portion of the optimizer state dict associated with the shard
...@@ -2439,14 +2587,21 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2439,14 +2587,21 @@ class FullyShardedDataParallel(nn.Module):
This can be used to get the right sharded optimizer state to be loaded This can be used to get the right sharded optimizer state to be loaded
into the sharded optimizer for this FSDP rank. into the sharded optimizer for this FSDP rank.
.. warning:: The input state dict is modified in-place, assuming the original
full state isn't going to be used anymore. This is done so that
we don't need to copy the extra state in it. It is the caller's responsibility
to make a copy if it doesn't want the original state dict modified.
Args: Args:
full_optim_state_dict (dict): consolidated optimizer state returned by ``gather_full_optim_state``, or loaded from a checkpoint. full_optim_state_dict (dict):
consolidated optimizer state returned by ``gather_full_optim_state``,
or loaded from a checkpoint.
Returns: Returns:
(dict): a shard of the optimizer state. (dict): a shard of the optimizer state.
""" """
# Assert nesting is the same as it was at save time # Assert nesting is the same as it was at save time
instance_list = self._fsdp_instances() instance_list = get_fsdp_instances(self, skip_empty=True)
ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list)) ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"]) ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
if self.flatten_parameters: if self.flatten_parameters:
...@@ -2458,9 +2613,9 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2458,9 +2613,9 @@ class FullyShardedDataParallel(nn.Module):
), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}' ), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}'
# get the portion of dict associated with the shard, in place # get the portion of dict associated with the shard, in place
for id, s in full_optim_state_dict["state"].items(): for _id, s in full_optim_state_dict["state"].items():
for k, v in s.items(): for k, v in s.items():
if torch.is_tensor(v) and id not in ids_not_to_shard: if torch.is_tensor(v) and _id not in ids_not_to_shard:
v_shard, _ = self._get_shard(v) v_shard, _ = self._get_shard(v)
elif isinstance(v, list) and ou.is_singleton_tensor(v[0]): elif isinstance(v, list) and ou.is_singleton_tensor(v[0]):
# if we are resuming on larger world size, take first entry # if we are resuming on larger world size, take first entry
...@@ -2468,7 +2623,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2468,7 +2623,7 @@ class FullyShardedDataParallel(nn.Module):
assert ou.is_singleton_tensor(v_shard) assert ou.is_singleton_tensor(v_shard)
else: else:
v_shard = v # don't shard entries that are not tensors v_shard = v # don't shard entries that are not tensors
full_optim_state_dict["state"][id][k] = v_shard full_optim_state_dict["state"][_id][k] = v_shard
return full_optim_state_dict return full_optim_state_dict
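Given the in-place warning above, a caller that still needs the consolidated dict afterwards can copy it first. A small sketch; the checkpoint path and the ``fsdp`` / ``optim`` names are hypothetical placeholders.

import copy

import torch

full_sd = torch.load("full_optim_state.pt")  # hypothetical consolidated checkpoint
shard_sd = fsdp.get_shard_from_optim_state_dict(copy.deepcopy(full_sd))
optim.load_state_dict(shard_sd)
# full_sd is left untouched; passing it directly would modify it in-place.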
...@@ -2728,3 +2883,22 @@ def auto_wrap_bn( ...@@ -2728,3 +2883,22 @@ def auto_wrap_bn(
enable_wrap(config_auto_wrap_policy, wrapper_cls=FullyShardedDataParallel) if wrap_it else contextlib.suppress() enable_wrap(config_auto_wrap_policy, wrapper_cls=FullyShardedDataParallel) if wrap_it else contextlib.suppress()
): ):
return auto_wrap(module) return auto_wrap(module)
def get_fsdp_instances(mod: nn.Module, skip_empty: bool = False) -> List[FullyShardedDataParallel]:
"""Return, a list, if any, of the module/submodule is wrapped by FSDP within another module.
Args:
mod (nn.Module):
A nn.Module module to be scanned.
skip_empty (bool):
If True, skip wrappers without any parameters.
Default: False
"""
ret: List[FullyShardedDataParallel] = []
for m in mod.modules(): # including mod itself
if isinstance(m, FullyShardedDataParallel):
ret.append(m)
if skip_empty:
ret = list(filter(lambda x: len(cast(FullyShardedDataParallel, x).non_shared_params()) > 0, ret))
return ret
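A small usage sketch of the new helper; ``model`` is a placeholder for any module tree that may contain FSDP wrappers (e.g. produced by auto_wrap).

from fairscale.nn.data_parallel import FullyShardedDataParallel, get_fsdp_instances

wrapped = get_fsdp_instances(model)                     # includes `model` itself if wrapped
non_empty = get_fsdp_instances(model, skip_empty=True)  # only wrappers with their own params
print(f"{len(wrapped)} FSDP instances, {len(non_empty)} own non-shared params")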
...@@ -500,7 +500,11 @@ class FlattenParamsWrapper(nn.Module): ...@@ -500,7 +500,11 @@ class FlattenParamsWrapper(nn.Module):
# Unflatten the module automatically if the state_dict is non-flat. # Unflatten the module automatically if the state_dict is non-flat.
# Note, we check the flat_param_ prefix since custom names can be given and flat_param_0 is # Note, we check the flat_param_ prefix since custom names can be given and flat_param_0 is
# not always in the state dict's key list. # not always in the state dict's key list.
if self.is_flattened and not any(k.startswith("flat_param_") for k in state_dict.keys()): if (
self.num_params_managed > 0
and self.is_flattened
and not any(k.startswith("flat_param_") for k in state_dict.keys())
):
# This object is flatten but state_dict is not. So we unflatten and load. # This object is flatten but state_dict is not. So we unflatten and load.
with self.unflatten_params(): with self.unflatten_params():
return super().load_state_dict(state_dict, strict) return super().load_state_dict(state_dict, strict)
......
...@@ -32,11 +32,9 @@ from . import utils as utils ...@@ -32,11 +32,9 @@ from . import utils as utils
from . import jit as jit from . import jit as jit
from . import fft as fft from . import fft as fft
#MODIFIED BY TORCHGPIPE
from . import backends from . import backends
from . import distributed from . import distributed
from . import version from . import version
#END
class dtype: class dtype:
is_floating_point: builtins.bool is_floating_point: builtins.bool
...@@ -67,10 +65,8 @@ class device: ...@@ -67,10 +65,8 @@ class device:
type: str type: str
index: _int index: _int
#MODIFIED BY TORCHGPIPE
@overload @overload
def __init__(self, device: device) -> None: ... def __init__(self, device: device) -> None: ...
#END
@overload @overload
def __init__(self, device: Union[_int, str]) -> None: ... def __init__(self, device: Union[_int, str]) -> None: ...
...@@ -78,17 +74,14 @@ class device: ...@@ -78,17 +74,14 @@ class device:
@overload @overload
def __init__(self, type: str, index: _int) -> None: ... def __init__(self, type: str, index: _int) -> None: ...
#MODIFIED BY TORCHGPIPE
class Size(tuple): class Size(tuple):
def numel(self) -> _int: ... def numel(self) -> _int: ...
#END
#MODIFIED BY TORCHGPIPE
class Storage: class Storage:
def size(self) -> _int: ... def size(self) -> _int: ...
def element_size(self) -> _int: ... def element_size(self) -> _int: ...
def resize_(self, int) -> None: ... def resize_(self, int) -> None: ...
#END def data_ptr(self) -> _int: ...
# See https://github.com/python/mypy/issues/4146 for why these workarounds # See https://github.com/python/mypy/issues/4146 for why these workarounds
# is necessary # is necessary
...@@ -935,10 +928,8 @@ class Tensor: ...@@ -935,10 +928,8 @@ class Tensor:
def unique_consecutive(self, sorted=True, return_inverse=False, return_counts=False, dim=None): ... def unique_consecutive(self, sorted=True, return_inverse=False, return_counts=False, dim=None): ...
def lu(self, pivot=True, get_infos=False): ... def lu(self, pivot=True, get_infos=False): ...
#MODIFIED BY TORCHGPIPE
from .cuda import Stream from .cuda import Stream
def record_stream(self, stream: Optional[Stream]) -> None: ... def record_stream(self, stream: Optional[Stream]) -> None: ...
#END
@overload @overload
def __and__(self: Tensor, other: Number) -> Tensor: ... def __and__(self: Tensor, other: Number) -> Tensor: ...
...@@ -1924,7 +1915,5 @@ def clear_autocast_cache() -> None: ... ...@@ -1924,7 +1915,5 @@ def clear_autocast_cache() -> None: ...
# possible to type correctly # possible to type correctly
def nonzero(input: Tensor, *, out: Optional[Tensor]=None, as_tuple: Optional[_bool]=None): ... def nonzero(input: Tensor, *, out: Optional[Tensor]=None, as_tuple: Optional[_bool]=None): ...
#MODIFIED BY TORCHGPIPE
def is_grad_enabled() -> _bool: ... def is_grad_enabled() -> _bool: ...
__version__: str = ... __version__: str = ...
#END
...@@ -32,6 +32,7 @@ class Module(Generic[T_co]): ...@@ -32,6 +32,7 @@ class Module(Generic[T_co]):
def add_module(self, name: str, module: 'Module') -> None: ... def add_module(self, name: str, module: 'Module') -> None: ...
def apply(self: T, fn: Callable[['Module'], None]) -> T: ... def apply(self: T, fn: Callable[['Module'], None]) -> T: ...
def _apply(self: T, fn: Callable[['Module'], None]) -> T: ...
def cuda(self: T, device: Optional[Union[int, str, device]] = ...) -> T: ... def cuda(self: T, device: Optional[Union[int, str, device]] = ...) -> T: ...
......
...@@ -16,6 +16,8 @@ import numpy as np ...@@ -16,6 +16,8 @@ import numpy as np
import pytest import pytest
import torch import torch
pytestmark = pytest.mark.skip(reason="ssd offload to be removed to simplify the code")
try: try:
import fairscale.experimental.nn.ssd_offload as so import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie: except ImportError as ie:
......
...@@ -658,6 +658,25 @@ class TestModuleProperties(DistributedTest): ...@@ -658,6 +658,25 @@ class TestModuleProperties(DistributedTest):
torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].cpu().shape) torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].cpu().shape)
class TestResetParameters(DistributedTest):
def test_reset_parameters(self):
"""Ensure that reduce_scatter_process_group same size with the world size."""
test_fn = functools.partial(self._test_reset, config={})
spawn_and_init(test_fn, world_sizes=[2])
@classmethod
def _test_reset(self, rank, group, config):
model = self._get_model(group, config)
with model.summon_full_params():
model.reset_parameters()
@classmethod
def _get_model(self, group, config):
with torch.no_grad(): # required for multiprocessing
model = nn.Linear(10, 10)
return FullyShardedDataParallel(model, group, allow_reset_parameters=True, **config)
class TransformerWithSharedParams(nn.Module): class TransformerWithSharedParams(nn.Module):
def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs): def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
super().__init__() super().__init__()
......
...@@ -189,10 +189,10 @@ class TestGradAccCommunication(DistributedTest): ...@@ -189,10 +189,10 @@ class TestGradAccCommunication(DistributedTest):
# the sum of the _base and public methods should stay the same. # the sum of the _base and public methods should stay the same.
assert ( assert (
mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather1 mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather1
), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather1}" ), f"{mock_all_gather.call_count} + {mock_all_gather_base.call_count} != {expected_all_gather1}"
assert ( assert (
mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count == 0 mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count == 0
), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != 0" ), f"{mock_reduce_scatter.call_count} + {mock_reduce_scatter_base.call_count} != 0"
output = model(*batch) output = model(*batch)
loss = model.module.get_loss(batch, output) loss = model.module.get_loss(batch, output)
...@@ -200,11 +200,11 @@ class TestGradAccCommunication(DistributedTest): ...@@ -200,11 +200,11 @@ class TestGradAccCommunication(DistributedTest):
assert ( assert (
mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather2 mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather2
), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather2}" ), f"{mock_all_gather.call_count} + {mock_all_gather_base.call_count} != {expected_all_gather2}"
assert ( assert (
mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count
== expected_reduce_scatter == expected_reduce_scatter
), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != {expected_reduce_scatter}" ), f"{mock_reduce_scatter.call_count} + {mock_reduce_scatter_base.call_count} != {expected_reduce_scatter}"
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -16,6 +16,8 @@ import torch ...@@ -16,6 +16,8 @@ import torch
from torch import nn from torch import nn
import torch.distributed import torch.distributed
pytestmark = pytest.mark.skip(reason="ssd offload to be removed to simplify the code")
try: try:
import fairscale.experimental.nn.ssd_offload as so import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie: except ImportError as ie:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# #
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import copy
import functools import functools
from time import time from time import time
import unittest import unittest
...@@ -13,7 +14,7 @@ from torch.optim import SGD, Adadelta, Adam # type: ignore ...@@ -13,7 +14,7 @@ from torch.optim import SGD, Adadelta, Adam # type: ignore
from fair_dev.testing.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes from fair_dev.testing.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes
from fairscale.internal.params import recursive_copy_to_device from fairscale.internal.params import recursive_copy_to_device
from fairscale.nn import FullyShardedDataParallel from fairscale.nn.data_parallel import FullyShardedDataParallel, get_fsdp_instances
from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
from .test_fsdp import ( from .test_fsdp import (
...@@ -158,9 +159,9 @@ class TestOptimizerUtils(DistributedTest): ...@@ -158,9 +159,9 @@ class TestOptimizerUtils(DistributedTest):
unwrapped_sd = optim_unwrapped.state_dict() unwrapped_sd = optim_unwrapped.state_dict()
if not transformer and not expert_group: if not transformer and not expert_group:
no_broadcast_children = [x for x in fsdp._fsdp_instances() if x.no_broadcast_optim_state] no_broadcast_children = [x for x in get_fsdp_instances(fsdp) if x.no_broadcast_optim_state]
assert len(no_broadcast_children) == 1, f"Length of non shared params {len(no_broadcast_children)}" assert len(no_broadcast_children) == 1, f"Length of non shared params {len(no_broadcast_children)}"
assert fsdp._fsdp_instances()[-1].no_broadcast_optim_state assert get_fsdp_instances(fsdp)[-1].no_broadcast_optim_state
torch.cuda.empty_cache() torch.cuda.empty_cache()
cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024**3 cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024**3
tstart = time() tstart = time()
...@@ -196,12 +197,15 @@ class TestOptimizerUtils(DistributedTest): ...@@ -196,12 +197,15 @@ class TestOptimizerUtils(DistributedTest):
) )
return return
unflat_state = sd["state"]
assert "uncollected_local_ids" in sd assert "uncollected_local_ids" in sd
shard_sd = fsdp.get_shard_from_optim_state_dict(sd) sd_copy = copy.deepcopy(sd)
unflat_state = sd_copy["state"]
shard_sd = fsdp.get_shard_from_optim_state_dict(sd_copy)
shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu") shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
state_after_get_shard = sd["state"] state_after_get_shard = sd_copy["state"]
assert objects_are_equal(unflat_state, state_after_get_shard) # no side effects. # sd is changed in-place in case there are extra states.
assert not objects_are_equal(unflat_state, state_after_get_shard)
del sd_copy
assert_equal(len(sd["state"]), len(unwrapped_sd["state"])) assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"])) assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
...@@ -223,8 +227,8 @@ class TestOptimizerUtils(DistributedTest): ...@@ -223,8 +227,8 @@ class TestOptimizerUtils(DistributedTest):
[v for k, v in shard_sd["param_groups"][0].items()], [v for k, v in shard_sd["param_groups"][0].items()],
[v for k, v in original_shard_sd["param_groups"][0].items()], [v for k, v in original_shard_sd["param_groups"][0].items()],
) )
assert objects_are_equal(shard_sd["state"], original_shard_sd["state"]) objects_are_equal(shard_sd["state"], original_shard_sd["state"], raise_exception=True)
assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd) objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd, raise_exception=True)
@parameterized.expand( @parameterized.expand(
[(True,), (False,)], [(True,), (False,)],
...@@ -260,7 +264,7 @@ class TestOptimizerUtils(DistributedTest): ...@@ -260,7 +264,7 @@ class TestOptimizerUtils(DistributedTest):
model = TransformerWithSharedParams(group) model = TransformerWithSharedParams(group)
named_pars = [p for n, p in model.named_parameters()] named_pars = [p for n, p in model.named_parameters()]
for i, p in enumerate(model.parameters()): for i, p in enumerate(model.parameters()):
assert objects_are_equal(p, named_pars[i]) objects_are_equal(p, named_pars[i], raise_exception=True)
def test_is_singleton_tensor(self): def test_is_singleton_tensor(self):
"""Test is_singleton_tensor function""" """Test is_singleton_tensor function"""
......
...@@ -158,7 +158,8 @@ def _dist_worker(rank, world_size, files, wrap_middle, test_fn): ...@@ -158,7 +158,8 @@ def _dist_worker(rank, world_size, files, wrap_middle, test_fn):
# We don't raise exceptions in CI since CI's T4 machine seems to be flaky with this test. # We don't raise exceptions in CI since CI's T4 machine seems to be flaky with this test.
# On devel machines, we do want to catch potential errors. There could be real bugs or # On devel machines, we do want to catch potential errors. There could be real bugs or
# system issues behind the flakiness. One example is all-reduce vs. simulated averaging # system issues behind the flakiness. One example is all-reduce vs. simulated averaging
# below. # below. The check also fails on my rtx 20xx. So maybe it only works on devfair with
# Quadro GP100 GPUs. TODO (Min): debug this.
objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=not in_circle_ci()) objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=not in_circle_ci())
elif test_fn == "eval": elif test_fn == "eval":
_eval(fsdp_model, in_data) _eval(fsdp_model, in_data)
......