Unverified commit 1fcbd624, authored by Sam Shleifer, committed by GitHub
[FSDP] add no_broadcast_optim_state option (#560)

parent 54a97ee5
......@@ -21,22 +21,22 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
non_tensor_state = {}
# Populate `new_state["state"]`. (Assuming sd is sorted)
for expanded_pid, buffers in sd["state"].items():
consolidated_pid = param_id_map[expanded_pid]
for global_id, buffers in sd["state"].items():
local_id = param_id_map[global_id]
for buffer_name, p in buffers.items():
if torch.is_tensor(p):
if buffer_name not in new_state[consolidated_pid]:
new_state[consolidated_pid][buffer_name] = []
new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
if buffer_name not in new_state[local_id]:
new_state[local_id][buffer_name] = []
new_state[local_id][buffer_name].append(p.reshape(-1))
else:
non_tensor_state[buffer_name] = p
# Now combine all tensors in each buffer using torch.cat().
for consolidated_pid, state in new_state.items():
for local_id, state in new_state.items():
for buffer_name, tensors in state.items():
new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
new_state[consolidated_pid].update(non_tensor_state)
new_sd = {"state": new_state, "param_groups": sd["param_groups"]}
new_state[local_id][buffer_name] = torch.cat(tensors)
new_state[local_id].update(non_tensor_state)
new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
# add pointers from the `params` dict.
for pg_id, _ in enumerate(sd["param_groups"]):
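For readers skimming the diff, the renamed loop above does the inverse of consolidation: each global (unflat) param id in the saved dict is mapped back to the local id of the flattened parameter that owns it, its tensor buffers are concatenated, and non-tensor state (e.g. Adam's step counter) is copied over. Below is a minimal, self-contained sketch with made-up toy values; the real function also rewires ``param_groups``.

```python
import torch

# Toy consolidated ("unflat") optimizer state: two global param ids that
# param_id_map folds into a single local (flattened) parameter with id 0.
sd = {
    "state": {
        0: {"exp_avg": torch.ones(3), "step": 10},
        1: {"exp_avg": torch.zeros(2), "step": 10},
    },
    "param_id_map": {0: 0, 1: 0},
    "param_groups": [{"lr": 0.01, "params": [0, 1]}],
}

new_state = {0: {}}  # one entry per local (flattened) parameter
non_tensor_state = {}
for global_id, buffers in sd["state"].items():
    local_id = sd["param_id_map"][global_id]
    for buffer_name, p in buffers.items():
        if torch.is_tensor(p):
            new_state[local_id].setdefault(buffer_name, []).append(p.reshape(-1))
        else:
            non_tensor_state[buffer_name] = p  # e.g. Adam's "step" counter
for local_id, state in new_state.items():
    for buffer_name, tensors in state.items():
        new_state[local_id][buffer_name] = torch.cat(tensors)
    new_state[local_id].update(non_tensor_state)

print(new_state[0]["exp_avg"].shape)  # torch.Size([5]): the two buffers re-flattened
```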
......@@ -109,6 +109,7 @@ def _unflatten_optim_state(
# If the constant state is the same as the combined state, copy it N times, no unflattening needed.
unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))}
if non_tensor_state[0].keys() == combined_state[0].keys():
return unflat_state, global_to_local_id
......@@ -134,24 +135,33 @@ def _unflatten_optim_state(
return unflat_state, global_to_local_id
def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
def build_unflat_state_dict(
instance_list: List[torch.nn.Module], world_optim_states: List[Dict], uncollected_opt_state: Dict[int, Dict]
) -> Dict:
"""Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
assert all(len(s) == len(instance_list) for s in world_pad_info)
assert all(len(s[0]) == 1 for s in world_pad_info)
# Since there are no tensors in param_groups, deepcopy is fine
param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
assert len(param_groups) == 1
# Aggregate from a list of dictionaries to a dictionary of lists
combined_state = _combine_state([x["state"] for x in world_optim_states])
for local_id, v in uncollected_opt_state.items():
assert local_id not in combined_state
combined_state[local_id] = {}
for buffer_name, tensor in v.items():
combined_state[local_id][buffer_name] = [tensor]
del world_optim_states
# local ids are in the current state; global ids will be in the returned state.
unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
num_params = sum([len(m._param_numels) for m in instance_list]) # type: ignore
param_groups[0]["params"] = list(range(num_params)) # This could be a large list. #TODO: is it essential
param_groups[0]["params"] = list(range(num_params))
return {
"state": dict(sorted(unflat_state.items())), # NOTE: this is probably already sorted
"param_id_map": global_to_local_id,
"param_groups": param_groups,
"uncollected_local_ids": list(uncollected_opt_state.keys()),
}
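The new ``uncollected_opt_state`` argument is what makes ``no_broadcast_optim_state`` work: state that rank 0 kept locally is injected into ``combined_state`` as single-element lists, so the later concatenate/unflatten path treats it like state gathered from a world of size one. A hedged sketch with invented toy values:

```python
from typing import Dict, List
import torch

# combined_state as produced by _combine_state: local id -> buffer name -> list
# of per-rank tensors (here a parameter whose state was gathered from 2 ranks).
combined_state: Dict[int, Dict[str, List[torch.Tensor]]] = {
    0: {"exp_avg": [torch.ones(4), torch.ones(4)]},
}
# State rank 0 kept locally because the owning instance set no_broadcast_optim_state.
uncollected_opt_state = {1: {"exp_avg": torch.zeros(3)}}

for local_id, v in uncollected_opt_state.items():
    assert local_id not in combined_state
    combined_state[local_id] = {name: [tensor] for name, tensor in v.items()}

# Both entries now have the list-of-tensors layout that the unflatten step expects.
print({k: [t.shape for t in v["exp_avg"]] for k, v in combined_state.items()})
# {0: [torch.Size([4]), torch.Size([4])], 1: [torch.Size([3])]}
```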
......@@ -157,6 +157,12 @@ class FullyShardedDataParallel(nn.Module):
device, the param's device will be used. If not given and module
params are on CPU, then the current CUDA device (as indicated by
``torch.cuda.current_device()`` will be used.
no_broadcast_optim_state: (bool, Optional)
    do not broadcast this module's optimizer state when ``gather_full_optim_state_dict`` is called.
    If you set this to True, you are expected to overwrite the relevant state entries of the
    returned optimizer state dict with the proper state at each rank. This is useful for
    situations like Mixture of Experts, where all but a few parameters can fit on one node.
    Default: False
"""
def __init__(
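A usage sketch for the new flag, loosely mirroring the ``MixtureOfExperts`` test fixture further down. It assumes ``torch.distributed`` has already been initialized on every rank; module shapes and names are made up for illustration.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes dist.init_process_group(...) has already been called on every rank.
rank = dist.get_rank()

# Expert parameters differ per rank, so wrap them in a world-size-1 group and
# opt out of optimizer-state broadcasting; shared parameters use the default group.
expert = nn.Linear(16, 4)
expert_group = dist.new_group([rank])  # world size 1: the expert is not sharded
expert_fsdp = FSDP(expert, process_group=expert_group, no_broadcast_optim_state=True)

model = FSDP(nn.Sequential(nn.Linear(8, 16), expert_fsdp, nn.Linear(4, 8)))
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# ... train for a few steps ...

# Rank 0 receives the consolidated dict; the expert's entries are rank-0
# placeholders (their ids are listed under "uncollected_local_ids") and must be
# overwritten with the proper per-rank state before being consumed elsewhere.
full_osd = model.gather_full_optim_state_dict(optim, recipient_rank=0)
```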
......@@ -173,6 +179,7 @@ class FullyShardedDataParallel(nn.Module):
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
compute_device: Optional[torch.device] = None,
no_broadcast_optim_state: Optional[bool] = False,
):
super().__init__()
self.process_group = process_group or dist.new_group()
......@@ -187,6 +194,8 @@ class FullyShardedDataParallel(nn.Module):
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.uncollected_opt_state: Dict[int, Dict] = {}
self.no_broadcast_optim_state = no_broadcast_optim_state
self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
......@@ -849,6 +858,12 @@ class FullyShardedDataParallel(nn.Module):
if m.process_group != self.process_group:
self.children_share_process_group = False
# If a child instance is in its own (smaller) world, that was probably an attempt to avoid OOM.
# Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
(m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
)
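Because of this heuristic, an expert wrapped in its own world-size-1 process group is flagged automatically even when the caller never passes ``no_broadcast_optim_state=True``; the updated test below relies on exactly that. A quick check, reusing ``model`` and the ``FSDP`` alias from the earlier sketch:

```python
# Sketch: after at least one forward pass the root has scanned its children,
# so only the expert wrapper should carry the flag.
flagged = [m for m in model.modules()
           if isinstance(m, FSDP) and m.no_broadcast_optim_state]
assert len(flagged) == 1
```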
def _setup_streams(self) -> None:
"""Create streams to overlap data transfer and computation."""
if len(self._streams) > 0 or not self._is_root:
......@@ -1391,7 +1406,7 @@ class FullyShardedDataParallel(nn.Module):
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
for rank in range(self.world_size):
if rank == self.rank:
sd = optim.state_dict()
sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
else:
sd = dummy_tensor # type: ignore
......@@ -1428,8 +1443,11 @@ class FullyShardedDataParallel(nn.Module):
if self.rank != recipient_rank and recipient_rank is not None:
return None
# Unify the shard states by concatenating tensors and unflattening params
new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states)
# TODO: check if this code supports nested instances with different world sizes
new_state_dict = ou.build_unflat_state_dict(
self._fsdp_instances, world_optim_states, self.uncollected_opt_state
)
self.uncollected_opt_state = {}
assert "uncollected_local_ids" in new_state_dict
return new_state_dict
@property
......@@ -1437,6 +1455,17 @@ class FullyShardedDataParallel(nn.Module):
"""Returns all fsdp modules in self.modules() including self."""
return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
if self.rank == 0:
# Save placeholders for uncollected opt state to keep the same unflat OSD format.
self.uncollected_opt_state = {k: v for k, v in osd["state"].items() if k in uncollected_ids}
pg = copy.deepcopy(osd["param_groups"])
new_dct["param_groups"] = pg
return new_dct
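To make the new helper's behavior concrete: on every rank the uncollected ids are dropped from the state dict that will be broadcast, and rank 0 additionally stashes them so ``build_unflat_state_dict`` can re-insert them as placeholders. A standalone toy re-implementation with invented values (not the actual method):

```python
import copy
import torch

def split_uncollectable(osd, uncollected_ids, is_rank_zero):
    """Toy version of the filter: returns (dict to broadcast, state kept on rank 0)."""
    to_broadcast = {
        "state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids},
        "param_groups": copy.deepcopy(osd["param_groups"]),
    }
    kept = {k: v for k, v in osd["state"].items() if k in uncollected_ids} if is_rank_zero else {}
    return to_broadcast, kept

osd = {
    "state": {0: {"exp_avg": torch.ones(4)}, 1: {"exp_avg": torch.zeros(3)}},
    "param_groups": [{"lr": 0.01, "params": [0, 1]}],
}
to_broadcast, kept = split_uncollectable(osd, uncollected_ids={1}, is_rank_zero=True)
print(sorted(to_broadcast["state"]), sorted(kept))  # [0] [1]
```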
def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Get the portion of the optimizer state dict associated with the shard
......@@ -1451,18 +1480,19 @@ class FullyShardedDataParallel(nn.Module):
"""
# Assert nesting is the same as it was at save time
instance_list = self._fsdp_instances
assert all(
x.world_size == self.world_size for x in instance_list
), "all nested instances must have same world size"
ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
if self.flatten_parameters:
full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
assert len(full_optim_state_dict["state"]) in (0, len(instance_list))
assert len(full_optim_state_dict["state"]) in (
0,
len(instance_list),
), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}'
# get the portion of dict associated with the shard, in place
for id, s in full_optim_state_dict["state"].items():
for k, v in s.items():
if torch.is_tensor(v):
if torch.is_tensor(v) and id not in ids_not_to_shard:
v_shard, _ = self._get_shard(v)
else:
v_shard = v  # don't shard entries that are not tensors
......
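On the restore side, ``get_shard_from_optim_state_dict`` now passes tensors whose id is listed under ``uncollected_local_ids`` through unchanged instead of slicing out this rank's shard. A toy illustration with a fake two-way shard helper (``fake_get_shard`` is invented; the real code calls ``self._get_shard``, which also handles padding):

```python
import torch

def fake_get_shard(t, rank=0, world_size=2):
    # Stand-in for FSDP._get_shard: even split, no padding handling.
    return t.chunk(world_size)[rank]

full_state = {
    0: {"exp_avg": torch.arange(8.0)},  # regular parameter: sharded
    1: {"exp_avg": torch.arange(3.0)},  # expert parameter: left whole
}
ids_not_to_shard = [1]

shard_state = {
    pid: {
        name: fake_get_shard(v) if torch.is_tensor(v) and pid not in ids_not_to_shard else v
        for name, v in bufs.items()
    }
    for pid, bufs in full_state.items()
}
print(shard_state[0]["exp_avg"].numel(), shard_state[1]["exp_avg"].numel())  # 4 3
```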
......@@ -782,6 +782,7 @@ class MixtureOfExperts(NestedWrappedModule):
# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
expert = nn.Linear(16, 4)
self.num_expert_params = sum([p.numel() for p in expert.parameters()])
for p in expert.parameters():
p.expert = True
......@@ -795,7 +796,7 @@ class MixtureOfExperts(NestedWrappedModule):
if wrapper_config is not None:
# we create a process group of size 1 for the expert params
expert_group = torch.distributed.new_group([group.rank()])
expert_group = torch.distributed.new_group([group.rank()]) # world size 1 means no shard
expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)
shared = FullyShardedDataParallel(shared, group, **wrapper_config)
......
......@@ -16,7 +16,7 @@ from fairscale.utils.testing import objects_are_equal
from .test_fsdp import (
DistributedTest,
DummyProcessGroup,
NestedWrappedModule,
MixtureOfExperts,
TransformerWithSharedParams,
rename_test,
spawn_and_init,
......@@ -36,11 +36,12 @@ def assert_equal(a, b):
class TestOptimizerUtils(DistributedTest):
@parameterized.expand(
[[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True]],
[[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True], [Adam, True]],
name_func=rename_test,
)
def test_consolidate_optimizer(self, optim_fn, transformer):
config = {"mixed_precision": True, "flatten_parameters": True}
config["compute_dtype"] = torch.float32
test_fn = functools.partial(
self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
)
......@@ -53,11 +54,11 @@ class TestOptimizerUtils(DistributedTest):
# Establish reference behavior.
if transformer:
unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
fsdp = self.get_wrapped_model(group, config=config).cuda()
unwrapped_model = TransformerWithSharedParams(group).cuda()
else:
fsdp = FullyShardedDataParallel(NestedWrappedModule(group, wrapper_config=config), group, **config).cuda()
unwrapped_model = NestedWrappedModule(group, wrapper_config=None).cuda()
unwrapped_model = MixtureOfExperts(group, wrapper_config=None).cuda()
fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()
try:
fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,)
......@@ -68,19 +69,24 @@ class TestOptimizerUtils(DistributedTest):
fsdp_optim.zero_grad()
optim_unwrapped.zero_grad()
x = fsdp.module.get_input(torch.device("cuda"))
output = fsdp(*x)
loss = fsdp.module.get_loss(x, output).to("cuda")
fsdp.module.run_backward(loss)
fsdp_optim.step()
output = unwrapped_model(*x)
loss = unwrapped_model.get_loss(x, output)
unwrapped_model.run_backward(loss)
optim_unwrapped.step()
with torch.cuda.amp.autocast(enabled=True):
x = fsdp.module.get_input(torch.device("cuda"))
output = fsdp(*x)
loss = fsdp.module.get_loss(x, output).to("cuda")
fsdp.module.run_backward(loss)
fsdp_optim.step()
output = unwrapped_model(*x)
loss = unwrapped_model.get_loss(x, output)
unwrapped_model.run_backward(loss)
optim_unwrapped.step()
unwrapped_sd = optim_unwrapped.state_dict()
if not transformer:
no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
assert len(no_broadcast_children) == 1
assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
tstart = time()
sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
duration = time() - tstart
......@@ -88,7 +94,14 @@ class TestOptimizerUtils(DistributedTest):
assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"
if fsdp.rank > 0:
assert sd is None
return
unflat_state = sd["state"]
assert "uncollected_local_ids" in sd
shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
state_after_get_shard = sd["state"]
assert objects_are_equal(unflat_state, state_after_get_shard) # no side effects.
assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
......@@ -97,18 +110,21 @@ class TestOptimizerUtils(DistributedTest):
sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]),
)
shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
original_shard_sd = fsdp_optim.state_dict()
assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
assert_equal(shard_sd.keys(), original_shard_sd.keys())
original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")
# Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
assert_equal(
[first_tensor_numel(v) for k, v in shard_sd["state"].items()],
[first_tensor_numel(v) for k, v in original_shard_sd["state"].items()],
)
assert_equal(
sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]),
sum([first_tensor_numel(v) for k, v in original_shard_sd["state"].items()]),
[v for k, v in shard_sd["param_groups"][0].items()],
[v for k, v in original_shard_sd["param_groups"][0].items()],
)
assert objects_are_equal(shard_sd, original_shard_sd)
assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)
def test_named_params_ordering(self):
"""Test assumption of consolidate_optimizer_state_dict"""
......