Unverified commit d8fc94d9, authored by Min Xu, committed by GitHub

[bug] fix optim state gather when there are empty FSDP instances (#1071)

* [bug] fix optim state gather when there are empty FSDP instances

* fixes an assert and a test bug
parent 203dd668
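For context while reading the diff below: the fix turns `_fsdp_instances` into a method with a `skip_empty` flag, so FSDP instances that own no non-shared parameters are skipped when instances are paired with optimizer state entries. Here is a minimal standalone sketch of that filtering idea (not part of this commit; `FakeFSDP` and the free-standing `fsdp_instances` helper are hypothetical stand-ins):

```python
from typing import List


class FakeFSDP:
    """Hypothetical stand-in for an FSDP instance; only illustrates the filter."""

    def __init__(self, name: str, n_params: int) -> None:
        self.name = name
        self._n_params = n_params

    def non_shared_params(self) -> List[int]:
        # Real FSDP returns parameter objects; integers are enough for the sketch.
        return list(range(self._n_params))


def fsdp_instances(modules: List[FakeFSDP], skip_empty: bool = False) -> List[FakeFSDP]:
    # Mirrors the new signature: optionally drop instances that own no
    # non-shared params, so list indices line up with optimizer state keys.
    result = list(modules)
    if skip_empty:
        result = [m for m in result if len(m.non_shared_params()) > 0]
    return result


modules = [FakeFSDP("root", 3), FakeFSDP("empty_inner", 0), FakeFSDP("leaf", 1)]
assert [m.name for m in fsdp_instances(modules, skip_empty=True)] == ["root", "leaf"]
assert len(fsdp_instances(modules)) == 3  # default keeps everything, matching the old behavior
```

The default `skip_empty=False` keeps the old behavior for callers that still need every instance, such as the `uncollected_ids` computation further down in the diff.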
@@ -1343,11 +1343,10 @@ class FullyShardedDataParallel(nn.Module):
         for n, m in self.named_modules():
             # `n != ""` excludes self.
             if n != "" and isinstance(m, FullyShardedDataParallel):
-                # We relax the assert for non-root instances, when the nested initialized module is wrapped
-                # again in FSDP later, for example after training to run inference.
-                assert m._is_root is None or not m._is_root, f"offending FSDP instance is {id(m)}, {m}"
-                if m._is_root is None:
-                    m._is_root = False
+                # We set inner FSDP instances to non-root, but they could have the value True
+                # if, for example, a module is called first on its own (like inference, EMA)
+                # and only later do we call an outer FSDP for state dict load/save.
+                m._is_root = False
                 if m.process_group != self.process_group:
                     self.children_share_process_group = False
@@ -2277,9 +2276,11 @@ class FullyShardedDataParallel(nn.Module):
             raise ValueError(msg)
 
     def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
-        """Collect [x.numel_padded_per_param for x in self._fsdp_instances] from each rank."""
+        """Collect [x.numel_padded_per_param for x in self._fsdp_instances()] from each rank."""
         world_pad_info: List[List[List[int]]] = []  # this will contain values from the whole world.
-        my_pad_info: List[List[int]] = [cast(List[int], m.numel_padded_per_param) for m in self._fsdp_instances]
+        my_pad_info: List[List[int]] = [
+            cast(List[int], m.numel_padded_per_param) for m in self._fsdp_instances(skip_empty=True)
+        ]
         for rank in range(self.world_size):
             if rank == self.rank:
                 pad_info = my_pad_info
@@ -2296,24 +2297,31 @@ class FullyShardedDataParallel(nn.Module):
         """For each value in state[i], if the value is a tensor, collect it from the world. Else use rank 0's entry."""
         gathered_state: Dict[int, Dict[str, List[Any]]] = {}
         singleton_state: Dict[int, Dict[str, List[Any]]] = {}  # Dimensionless tensor
+        # There must be at least as many non-empty FSDP instances as sd_state items.
+        fsdp_instances = self._fsdp_instances(skip_empty=True)
+        assert len(fsdp_instances) >= len(sd_state), f"{len(fsdp_instances)} vs. {len(sd_state)}"
         for k, v in sd_state.items():
             gathered_state[k] = {}
             singleton_state[k] = {}
             # For shared params, we are not flattening. We have only 1 non-shared
             # param that has the optimizer state. So we handle it with the correct
             # parameter list.
-            non_shared_params = cast(FullyShardedDataParallel, self._fsdp_instances[k]).non_shared_params()
+            non_shared_params = fsdp_instances[k].non_shared_params()
             # This is the world size and process group of the FSDP submodule, which can be
             # different from the parent module's. For example, when FSDP is used with MoE.
-            non_shared_world_size = self._fsdp_instances[k].world_size
-            non_shared_process_group = self._fsdp_instances[k].process_group
+            non_shared_world_size = fsdp_instances[k].world_size
+            non_shared_process_group = fsdp_instances[k].process_group
             assert (
                 len(non_shared_params) == 1
-            ), f"Only flatten param or a single non-shared param is supported: len={len(non_shared_params)}"
+            ), f"Only flatten param or a single non-shared param is supported: len={len(non_shared_params)} FSDP={self}"
             desired_buffer_size = non_shared_params[0]._full_param_padded.size()
             buffer = None  # for sharded tensors
             singleton_buffer = None  # for singleton tensors
             for buffer_name, t in v.items():
                 if torch.is_tensor(t):
                     t = t.to(self.compute_device)
@@ -2370,16 +2378,23 @@ class FullyShardedDataParallel(nn.Module):
             return None
         # Unify the shard states by concatenating tensors and unflattening params
         new_state_dict = ou.build_unflat_state_dict(
-            self._fsdp_instances, pad_info, state, singleton_state, self.uncollected_opt_state, sd["param_groups"]
+            self._fsdp_instances(skip_empty=True),
+            pad_info,
+            state,
+            singleton_state,
+            self.uncollected_opt_state,
+            sd["param_groups"],
         )
         self.uncollected_opt_state = {}
         assert "uncollected_local_ids" in new_state_dict
         return new_state_dict
 
-    @property
-    def _fsdp_instances(self) -> List["FullyShardedDataParallel"]:
+    def _fsdp_instances(self, skip_empty: bool = False) -> List["FullyShardedDataParallel"]:
         """Returns all FSDP modules in self.modules(), including self."""
-        return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
+        result = [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
+        if skip_empty:
+            result = list(filter(lambda x: len(cast(FullyShardedDataParallel, x).non_shared_params()) > 0, result))
+        return result
 
     def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
         """Return a new state dict filtering out the ones like MoE layers, which have
@@ -2396,7 +2411,7 @@ class FullyShardedDataParallel(nn.Module):
             if ou.is_singleton_tensor(bufs["step"]):
                 bufs["step"] = bufs["step"].item()
         # Get uncollected_ids.
-        uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
+        uncollected_ids = [i for i, m in enumerate(self._fsdp_instances()) if m.no_broadcast_optim_state]
         new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
         if self.rank == 0:
             # Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
@@ -2423,7 +2438,7 @@ class FullyShardedDataParallel(nn.Module):
             (dict): a shard of the optimizer state.
         """
         # Assert nesting is the same as it was at save time
-        instance_list = self._fsdp_instances
+        instance_list = self._fsdp_instances()
         ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
         ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
         if self.flatten_parameters:
@@ -158,9 +158,9 @@ class TestOptimizerUtils(DistributedTest):
         unwrapped_sd = optim_unwrapped.state_dict()
 
         if not transformer and not expert_group:
-            no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
+            no_broadcast_children = [x for x in fsdp._fsdp_instances() if x.no_broadcast_optim_state]
             assert len(no_broadcast_children) == 1, f"Length of non shared params {len(no_broadcast_children)}"
-            assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
+            assert fsdp._fsdp_instances()[-1].no_broadcast_optim_state
         torch.cuda.empty_cache()
         cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024**3
         tstart = time()
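For completeness, a hedged usage sketch of the code path this fix touches: gathering the consolidated optimizer state and re-sharding it. It assumes `torch.distributed` is already initialized and that the fairscale `gather_full_optim_state_dict` / `get_shard_from_optim_state_dict` APIs apply here; the function name and file path are illustrative only, not from this commit.

```python
import torch
from fairscale.nn import FullyShardedDataParallel as FSDP


def checkpoint_and_reshard_optim(model: FSDP, optim: torch.optim.Optimizer) -> None:
    # Gather the unflattened, unsharded optimizer state across ranks. With this
    # fix, inner FSDP instances that own no non-shared params no longer break
    # the pairing between FSDP instances and optimizer state entries.
    full_osd = model.gather_full_optim_state_dict(optim)

    # In this sketch we assume only rank 0 receives the consolidated dict
    # (other ranks get None), so only rank 0 writes the checkpoint.
    if full_osd is not None:
        torch.save(full_osd, "optim_full.pt")  # illustrative path

    # Later, once every rank can see the file, each rank reloads the full
    # dict and extracts just its own shard before handing it to the optimizer.
    full_osd = torch.load("optim_full.pt", map_location="cpu")
    shard = model.get_shard_from_optim_state_dict(full_osd)
    optim.load_state_dict(shard)
```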