"vscode:/vscode.git/clone" did not exist on "e489abc684c864f1da010d56e3ca66dbd2df82fb"
Unverified Commit a82825db authored by Sam Shleifer, committed by GitHub

[FSDP] use all_gather for 10X OSD consolidation speedup (#595)

parent 4726d5be
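At a glance: the commit replaces the old rank-by-rank `broadcast_object` of each rank's entire optimizer state dict with `dist.all_gather` over the tensor state, keeping `broadcast_object` only for the small per-parameter padding counts. The sketch below shows how the updated `gather_full_optim_state_dict` API is meant to be called; the toy model, optimizer, and launch command are illustrative assumptions, not part of this diff.

# Hedged usage sketch, assuming one GPU per process and a launch such as
# `torchrun --nproc_per_node=N save_osd.py`.
import os

import torch
import torch.distributed as dist
from torch.optim import Adam

from fairscale.nn import FullyShardedDataParallel as FSDP


def main() -> None:
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    model = FSDP(torch.nn.Linear(8, 8).cuda(), flatten_parameters=True)
    optim = Adam(model.parameters(), lr=1e-3)

    # One dummy step so the optimizer actually has sharded state to gather.
    model(torch.randn(2, 8, device="cuda")).sum().backward()
    optim.step()

    # Must be called on every rank; only rank 0 receives the full dict.
    full_osd = model.gather_full_optim_state_dict(optim)
    if dist.get_rank() == 0:
        torch.save(full_osd, "full_optim_state.pt")


if __name__ == "__main__":
    main()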
@@ -4,10 +4,12 @@
# LICENSE file in the root directory of this source tree.
"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
import copy
from typing import Dict, Generator, List, Tuple
from typing import Any, Dict, Generator, List, Tuple
import torch
# These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}
# This function helps shard a full optimizer state dict
def flatten_optim_state_dict(sd: Dict) -> Dict:
@@ -16,6 +18,7 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
num_local_params = len(set(param_id_map.values()))
if sd["state"]:
new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
singleton_state: Dict = copy.deepcopy(new_state)
else:
new_state = {}
non_tensor_state = {}
@@ -24,19 +27,26 @@
for global_id, buffers in sd["state"].items():
local_id = param_id_map[global_id]
for buffer_name, p in buffers.items():
if torch.is_tensor(p):
if is_singleton_tensor(p):
singleton_state[local_id][buffer_name] = p
elif torch.is_tensor(p):
if buffer_name not in new_state[local_id]:
new_state[local_id][buffer_name] = []
new_state[local_id][buffer_name].append(p.reshape(-1))
elif isinstance(p, list):
singleton_state[local_id][buffer_name] = p
else:
non_tensor_state[buffer_name] = p
# Now combine all tensors in each buffer using torch.cat().
for local_id, state in new_state.items():
for buffer_name, tensors in state.items():
new_state[local_id][buffer_name] = torch.cat(tensors)
new_state[local_id].update(non_tensor_state)
new_state[local_id].update(singleton_state[local_id])
new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
for k in sd.keys(): # if there are extra keys, like loss_scale, don't delete them
if k not in UNFLAT_RETURN_KEYS:
new_sd[k] = copy.deepcopy(sd[k])
# add pointers from the `params` dict.
for pg_id, _ in enumerate(sd["param_groups"]):
@@ -70,22 +80,11 @@ def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_
return non_tensor_state
def _combine_state(states: List[Dict]) -> Dict[int, Dict]:
combined_state = states[0]
for param_id in combined_state:
combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
if len(states) == 1:
return combined_state
for rank, s in enumerate(states[1:]):
for param_id, param_state in s.items():
for k, tensor in param_state.items():
combined_state[param_id][k].append(tensor)
return combined_state
def _unflatten_optim_state(
combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]],
combined_state: Dict[int, Dict],
instance_list: List[torch.nn.Module],
world_pad_info: List[List[List[int]]],
singleton_state: Dict[int, Dict],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
# local ids are the keys in the current state (combined_state), (usually fewer)
# global ids will be the keys in the unflattened state
@@ -98,17 +97,17 @@ def _unflatten_optim_state(
non_tensor_state = [_extract_non_tensor_state(combined_state, id) for id in combined_state]
# local corresponds to flattened, global corresponds to unflattened
num_unflat_params = [len(m._param_numels) for m in instance_list] # type: ignore
num_global_params = [len(m._param_numels) for m in instance_list] # type: ignore
global_to_local_id = {}
for local_id, num_unflat in enumerate(num_unflat_params):
for local_id, num_unflat in enumerate(num_global_params):
for _ in range(num_unflat):
global_to_local_id[next_global_id] = local_id
next_global_id += 1
if not combined_state:
return {}, global_to_local_id
# If the constant state is the same as the combined state, copy it N times, no unflattening needed.
unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))}
# copy non tensor state to all global entries
unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_global_params))}
if non_tensor_state[0].keys() == combined_state[0].keys():
return unflat_state, global_to_local_id
@@ -131,37 +130,44 @@ def _unflatten_optim_state(
for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
unflat_state[global_id][k] = param_view
unflat_state[global_id].update(singleton_state[local_id])
return unflat_state, global_to_local_id
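The `global_to_local_id` map built above simply repeats each flat (local) id once per original (global) parameter that the flat parameter absorbed. A tiny worked example with made-up parameter counts:

# Hedged illustration of the global-to-local id mapping.
num_global_params = [3, 2]   # stand-in for len(m._param_numels) per FSDP instance

global_to_local_id = {}
next_global_id = 0
for local_id, num_unflat in enumerate(num_global_params):
    for _ in range(num_unflat):
        global_to_local_id[next_global_id] = local_id
        next_global_id += 1

print(global_to_local_id)    # -> {0: 0, 1: 0, 2: 0, 3: 1, 4: 1}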
def build_unflat_state_dict(
instance_list: List[torch.nn.Module], world_optim_states: List[Dict], uncollected_opt_state: Dict[int, Dict]
instance_list: List[torch.nn.Module],
world_pad_info: List[List[List[int]]],
state: Dict[int, Dict[str, List[torch.Tensor]]],
singleton_state: Dict[int, Dict[str, List[torch.Tensor]]],
uncollected_opt_state: Dict[int, Dict],
param_groups: List[Dict],
) -> Dict:
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
assert all(len(s) == len(instance_list) for s in world_pad_info)
assert all(len(s[0]) == 1 for s in world_pad_info)
# Since there are no tensors in param_groups, deepcopy is fine
param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
assert len(param_groups) == 1
# Aggregate from a list of dictionaries to a dictionary of lists
combined_state = _combine_state([x["state"] for x in world_optim_states])
# Use uncollected_opt_state to update tensor_state, singleton_state
for local_id, v in uncollected_opt_state.items():
assert local_id not in combined_state
combined_state[local_id] = {}
for buffer_name, tensor in v.items():
combined_state[local_id][buffer_name] = [tensor]
del world_optim_states
assert local_id not in state
state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if not is_singleton_tensor(x)}
singleton_state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if is_singleton_tensor(x)}
# local ids are in the current state, global_ids will be in returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
unflat_state, global_to_local_id = _unflatten_optim_state(state, instance_list, world_pad_info, singleton_state)
# Since there are no tensors in param_groups, deepcopy is fine
param_groups = copy.deepcopy(param_groups)
num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
param_groups[0]["params"] = list(range(num_params))
return {
unflat_optim_state_dict = {
"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted
"param_id_map": global_to_local_id,
"param_groups": param_groups,
"uncollected_local_ids": list(uncollected_opt_state.keys()),
}
assert set(unflat_optim_state_dict.keys()) == UNFLAT_RETURN_KEYS
return unflat_optim_state_dict
def is_singleton_tensor(x: Any) -> bool:
"""Is x a dimensionless tensor?"""
return torch.is_tensor(x) and x.dim() == 0
@@ -1382,70 +1382,88 @@ class FullyShardedDataParallel(nn.Module):
traceback.print_stack()
raise ValueError(msg)
def _consolidate_optim_state_dict(
self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
) -> List[Dict]:
"""Update the consolidated state_dict list, one per rank.
Args:
optim (Optimizer): an optimizer instance for this FSDP rank. Its state is
used in the consolidation. However, its state is not modified.
recipient_rank (int): on which rank to materialize the full state dict.
None is a special value, which means that all ranks should have the state
Returns:
all_states (list[dict]) the optimizer state from each rank
.. warning: This needs to be called on all replicas"""
self._lazy_init()
# NOTE(SS): we do not support param groups yet, as they seem to break FSDP
# Pull the sharded state from all the other replicas
# Store all the states in order, rank by rank
should_collect_state = recipient_rank is None or (self.rank == recipient_rank)
all_states: List[Dict[str, Any]] = []
def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
"""Collect [x.numel_padded_per_param for x in self._fsdp_instances] from teach rank."""
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
world_pad_info: List[List[List[int]]] = [] # this will contain values from the whole world.
for rank in range(self.world_size):
if rank == self.rank:
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
pad_info = [m.numel_padded_per_param for m in self._fsdp_instances]
else:
sd = dummy_tensor # type: ignore
sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)
if should_collect_state:
assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
return all_states
def gather_full_optim_state_dict(
self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = 0
) -> Optional[Dict[str, Any]]:
pad_info = dummy_tensor # type: ignore
pad_info = broadcast_object(
pad_info, src_rank=rank, group=self.process_group, dist_device=self.compute_device
)
if self.rank == 0:
world_pad_info.append(pad_info) # type: ignore
return world_pad_info
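Only the small per-FSDP-instance padding counts still travel rank-by-rank through `broadcast_object`; the heavy tensor state no longer does. For readers without fairscale's helper at hand, a rough stand-in using stock `torch.distributed.broadcast_object_list` (available since PyTorch 1.8) shows the same round-robin pattern, assuming an initialized default process group:

# Hedged sketch of the round-robin metadata broadcast, written against the
# stock torch.distributed API rather than fairscale's broadcast_object helper.
import torch.distributed as dist


def broadcast_pad_info_to_r0(my_pad_info):
    world_pad_info = []
    for rank in range(dist.get_world_size()):
        # broadcast_object_list fills the list in place on receiving ranks
        holder = [my_pad_info if rank == dist.get_rank() else None]
        dist.broadcast_object_list(holder, src=rank)
        if dist.get_rank() == 0:
            world_pad_info.append(holder[0])
    return world_pad_info  # only rank 0 accumulates every rank's entry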
def _gather_optim_state(
self, sd_state: Dict[int, Dict[str, Any]]
) -> Tuple[Dict[int, Dict[str, List]], Dict[int, Dict[str, List]]]:
"""For each value in state[i], if the value is a tensor, collect it from the world. Else use rank 0's entry."""
gathered_state: Dict[int, Dict[str, List[Any]]] = {}
singleton_state: Dict[int, Dict[str, List[Any]]] = {} # Dimensionless tensor
for k, v in sd_state.items():
gathered_state[k] = {}
singleton_state[k] = {}
desired_buffer_size = self._fsdp_instances[k].flat_param._full_param_padded.size() # type: ignore
buffer = None # for sharded tensors
singleton_buffer = None # for singleton tensors
for buffer_name, t in v.items():
if ou.is_singleton_tensor(t):
if singleton_buffer is None:
singleton_buffer = list(t.new_zeros(self.world_size).chunk(self.world_size))
dist.all_gather(singleton_buffer, t, group=self.process_group)
if self.rank == 0:
singleton_state[k][buffer_name] = [x.cpu().squeeze() for x in singleton_buffer]
assert ou.is_singleton_tensor(singleton_state[k][buffer_name][0])
elif torch.is_tensor(t):
if buffer is None:
buffer = list(t.new_zeros(*desired_buffer_size).chunk(self.world_size))
dist.all_gather(buffer, t, group=self.process_group)
if self.rank == 0:
gathered_state[k][buffer_name] = [x.cpu() for x in buffer]
elif self.rank == 0: # Add non tensor state
gathered_state[k][buffer_name] = [t]
return gathered_state, singleton_state
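The speedup advertised in the title comes from the pattern above: pre-allocate a zero buffer shaped like the full padded flat parameter, `chunk()` it into `world_size` equal views so `dist.all_gather` can write every rank's shard directly into place, and use a length-`world_size` buffer for 0-dim scalars (for example an optimizer step counter stored as a tensor). A condensed sketch of just that pattern (process-group setup is assumed; names are illustrative):

# Hedged sketch of the all_gather pattern used for state consolidation.
# Assumes an initialized default process group (e.g. launched with torchrun).
from typing import List

import torch
import torch.distributed as dist


def gather_sharded_tensor(shard: torch.Tensor, full_numel: int) -> torch.Tensor:
    # Gather one padded optimizer-state shard from every rank into a flat tensor.
    world_size = dist.get_world_size()
    assert full_numel == shard.numel() * world_size
    # chunk() on one contiguous zero buffer yields world_size views that
    # all_gather fills in place, avoiding per-rank object broadcasts.
    buffer: List[torch.Tensor] = list(shard.new_zeros(full_numel).chunk(world_size))
    dist.all_gather(buffer, shard)
    return torch.cat([b.cpu() for b in buffer])


def gather_scalar(scalar: torch.Tensor) -> List[torch.Tensor]:
    # Gather a 0-dim tensor (e.g. an optimizer step counter) from every rank.
    world_size = dist.get_world_size()
    buffer = list(scalar.new_zeros(world_size).chunk(world_size))
    dist.all_gather(buffer, scalar)
    return [x.cpu().squeeze() for x in buffer]  # back to 0-dim per rank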
def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, **ignored: Dict) -> Optional[Dict[str, Any]]:
"""Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
sharded properties are not exposed. Multiple parameter groups are not yet supported.
This should be called only on the root FSDP instance.
Nested FSDP instances are supported as long as they have the same world_size as the parent or world_size=1.
Different world_size groups in nested FSDP instances is not supported.
Args:
optim (Optimizer): an optimizer instance for this FSDP rank. Its state is
used in the consolidation. However, its state is not modified.
recipient_rank (int): on which rank to materialize the full state dict.
optim (Optimizer): an optimizer instance for this FSDP rank. Its state_dict is
used in the consolidation. However, its state is not modified.
Returns:
a dict with two entries
* A dict with four entries (On rank zero, other workers return ``None``)
* state - a dict holding gathered optimization state, 1 entry per unflat parameter
* param_groups - a dict containing the 1 parameter group
* param_id_map - global (unflat) to local (flat) id mapping
* uncollected_local_ids - keys in the state dict that were not broadcast
"""
if not self.flatten_parameters:
raise NotImplementedError("optim state dict requires flatten_parameters=True")
world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank)
if self.rank != recipient_rank and recipient_rank is not None:
self._lazy_init()
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
assert set(sd.keys()) == {"param_groups", "state"}, f'{set(sd.keys())} != {"param_groups", "state"}'
assert len(sd["param_groups"]) == 1, "Param groups are not supported"
# We use all_gather to consolidate OSD['state'] and broadcast to consolidate the other keys (like param_groups)
state, singleton_state = self._gather_optim_state(sd.pop("state"))
pad_info = self._broadcast_pad_info_to_r0()
if self.rank != 0:
return None
# Unify the shard states by concatenating tensors and unflattening params
new_state_dict = ou.build_unflat_state_dict(
self._fsdp_instances, world_optim_states, self.uncollected_opt_state
self._fsdp_instances, pad_info, state, singleton_state, self.uncollected_opt_state, sd["param_groups"]
)
self.uncollected_opt_state = {}
assert "uncollected_local_ids" in new_state_dict
@@ -1499,14 +1517,20 @@ class FullyShardedDataParallel(nn.Module):
for k, v in s.items():
if torch.is_tensor(v) and id not in ids_not_to_shard:
v_shard, _ = self._get_shard(v)
elif isinstance(v, list) and ou.is_singleton_tensor(v[0]):
# if we are resuming on larger world size, take first entry
v_shard = v[0] if self.rank >= len(v) else v[self.rank]
assert ou.is_singleton_tensor(v_shard)
else:
v_shard = v # don't shard entries that are not tensors
full_optim_state_dict["state"][id][k] = v_shard
return full_optim_state_dict
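The new `elif` branch above stores one scalar per original rank as a list; when resuming on a larger world size there are more ranks than saved entries, so the extra ranks fall back to entry 0. A tiny illustration with made-up values:

# Hedged illustration of re-sharding a per-rank singleton list after the
# world size grows (values below are made up).
import torch

saved = [torch.tensor(10), torch.tensor(10)]   # saved with world_size=2
for rank in range(4):                          # resuming with world_size=4
    v_shard = saved[0] if rank >= len(saved) else saved[rank]
    print(rank, v_shard.item())                # ranks 2 and 3 reuse entry 0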
def _print_r0(self, msg: str) -> None:
def _print_r0(self, msg: str, restart: bool = False) -> None:
"""Debugging utility to print memory usage stats nicely on rank 0"""
if restart:
self._tstart = time.time()
if self.rank == 0:
gb_denom = 1024 ** 3
print(
@@ -627,15 +627,19 @@ class MixtureOfExperts(NestedWrappedModule):
# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
d_expert = 16
expert = nn.Linear(d_expert, 4)
d_expert = 23
d_shared = 12
d_input = 8
expert = nn.Linear(d_expert, d_shared)
self.num_expert_params = sum([p.numel() for p in expert.parameters()])
for p in expert.parameters():
p.expert = True
# everything else is shared
torch.manual_seed(0)
shared = nn.Linear(4, d_expert)
shared = nn.Linear(d_shared, d_expert)
if checkpoint_act:
expert = checkpoint_wrapper(expert)
@@ -648,7 +652,7 @@ class MixtureOfExperts(NestedWrappedModule):
shared = FullyShardedDataParallel(shared, group, **wrapper_config)
self.module = nn.Sequential(nn.Linear(8, 4), shared, expert, nn.Linear(4, 8))
self.module = nn.Sequential(nn.Linear(d_input, d_shared), shared, expert, nn.Linear(d_shared, d_input))
def forward(self, x):
if self.delay_before_free_ms > 0:
@@ -10,6 +10,7 @@ import torch
from torch.optim import SGD, Adadelta, Adam # type: ignore
from fairscale.nn import FullyShardedDataParallel
from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
from fairscale.optim.utils import recursive_copy_to_device
from fairscale.utils.testing import objects_are_equal
@@ -147,3 +148,10 @@ class TestOptimizerUtils(DistributedTest):
named_pars = [p for n, p in model.named_parameters()]
for i, p in enumerate(model.parameters()):
assert objects_are_equal(p, named_pars[i])
def test_is_singleton_tensor(self):
assert is_singleton_tensor(torch.tensor(4.0))
assert not is_singleton_tensor(torch.tensor([4.0]))
assert not is_singleton_tensor(torch.tensor([4.0, 5.0]))
assert not is_singleton_tensor([4.0])
assert not is_singleton_tensor(4.0)