Unverified commit 1fcbd624, authored by Sam Shleifer, committed by GitHub

[FSDP] add no_broadcast_optim_state option (#560)

parent 54a97ee5
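
A minimal usage sketch of the new option (hypothetical module names and sizes; assumes torch.distributed is already initialized): an expert layer that holds rank-local parameters is wrapped in its own size-1 process group, and its optimizer state is then kept local instead of being broadcast by ``gather_full_optim_state_dict``. As the ``_set_is_root`` hunk below shows, a child in a size-1 group that differs from the parent's group is flagged automatically, so passing the flag explicitly is optional in that case.

import torch.distributed as dist
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Hypothetical Mixture-of-Experts layer: `expert` holds rank-local parameters,
# `shared` is an ordinary layer sharded across the full world as usual.
rank = dist.get_rank()
expert = nn.Linear(16, 4)
shared = nn.Linear(16, 16)

expert_group = dist.new_group([rank])  # size-1 group: the expert is never sharded
expert = FSDP(expert, process_group=expert_group, no_broadcast_optim_state=True)
shared = FSDP(shared)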
@@ -21,22 +21,22 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
     non_tensor_state = {}

     # Populate `new_state["state"]`. (Assuming sd is sorted)
-    for expanded_pid, buffers in sd["state"].items():
-        consolidated_pid = param_id_map[expanded_pid]
+    for global_id, buffers in sd["state"].items():
+        local_id = param_id_map[global_id]
         for buffer_name, p in buffers.items():
             if torch.is_tensor(p):
-                if buffer_name not in new_state[consolidated_pid]:
-                    new_state[consolidated_pid][buffer_name] = []
-                new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
+                if buffer_name not in new_state[local_id]:
+                    new_state[local_id][buffer_name] = []
+                new_state[local_id][buffer_name].append(p.reshape(-1))
             else:
                 non_tensor_state[buffer_name] = p

     # Now combine all tensors in each buffer using torch.cat().
-    for consolidated_pid, state in new_state.items():
+    for local_id, state in new_state.items():
         for buffer_name, tensors in state.items():
-            new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
-        new_state[consolidated_pid].update(non_tensor_state)
-    new_sd = {"state": new_state, "param_groups": sd["param_groups"]}
+            new_state[local_id][buffer_name] = torch.cat(tensors)
+        new_state[local_id].update(non_tensor_state)
+    new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}

     # add pointers from the `params` dict.
     for pg_id, _ in enumerate(sd["param_groups"]):
@@ -109,6 +109,7 @@ def _unflatten_optim_state(
     # If the constant state is the same as the combined state, copy it N times, no unflattening needed.
     unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))}
     if non_tensor_state[0].keys() == combined_state[0].keys():
         return unflat_state, global_to_local_id
@@ -134,24 +135,33 @@ def _unflatten_optim_state(
     return unflat_state, global_to_local_id


-def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
+def build_unflat_state_dict(
+    instance_list: List[torch.nn.Module], world_optim_states: List[Dict], uncollected_opt_state: Dict[int, Dict]
+) -> Dict:
     """Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
     world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
     assert all(len(s) == len(instance_list) for s in world_pad_info)
     assert all(len(s[0]) == 1 for s in world_pad_info)
+    # Since there are no tensors in param_groups, deepcopy is fine
     param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
     assert len(param_groups) == 1

     # Aggregate from a list of dictionaries to a dictionary of lists
     combined_state = _combine_state([x["state"] for x in world_optim_states])
+    for local_id, v in uncollected_opt_state.items():
+        assert local_id not in combined_state
+        combined_state[local_id] = {}
+        for buffer_name, tensor in v.items():
+            combined_state[local_id][buffer_name] = [tensor]
     del world_optim_states

     # local ids are in the current state, global_ids will be in returned state.
     unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
     num_params = sum([len(m._param_numels) for m in instance_list])  # type: ignore
-    param_groups[0]["params"] = list(range(num_params))  # This could be a large list. #TODO: is it essential
+    param_groups[0]["params"] = list(range(num_params))
     return {
         "state": dict(sorted(unflat_state.items())),  # NOTE: this is probably already sorted
         "param_id_map": global_to_local_id,
         "param_groups": param_groups,
+        "uncollected_local_ids": list(uncollected_opt_state.keys()),
     }
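
To make the new return value concrete, the dict that ``build_unflat_state_dict`` hands back now carries a fourth key, ``uncollected_local_ids``. An illustrative sketch of the shape (ids, sizes, and hyperparameters below are invented) for an SGD-with-momentum run in which the state for local id 1 was not gathered:

import torch

# Illustrative only: values are made up to show the structure.
unflat_osd = {
    "state": {
        0: {"momentum_buffer": torch.zeros(68)},  # gathered from all ranks
        1: {"momentum_buffer": torch.zeros(68)},  # rank-0 placeholder, not broadcast
    },
    "param_id_map": {0: 0, 1: 1},  # global (unflattened) id -> local (flattened) id
    "param_groups": [{"lr": 0.01, "momentum": 0.9, "params": [0, 1]}],
    "uncollected_local_ids": [1],  # entries each rank must overwrite with its own state
}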
@@ -157,6 +157,12 @@ class FullyShardedDataParallel(nn.Module):
             device, the param's device will be used. If not given and module
             params are on CPU, then the current CUDA device (as indicated by
             ``torch.cuda.current_device()`` will be used.
+        no_broadcast_optim_state: (bool, Optional)
+            do not broadcast this module's optimizer state when ``gather_full_optim_state_dict`` is called.
+            If you set this to True, you are expected to overwrite the relevant state entries of the returned
+            optimizer state dict with the proper state at each rank. This is useful for situations, like
+            Mixture of Experts, where all but a few parameters can fit on one node.
+            Default: False
     """

     def __init__(
@@ -173,6 +179,7 @@ class FullyShardedDataParallel(nn.Module):
         move_grads_to_cpu: Optional[bool] = None,
         bucket_cap_mb: int = 25,
         compute_device: Optional[torch.device] = None,
+        no_broadcast_optim_state: Optional[bool] = False,
     ):
         super().__init__()
         self.process_group = process_group or dist.new_group()
@@ -187,6 +194,8 @@ class FullyShardedDataParallel(nn.Module):
         self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
+        self.uncollected_opt_state: Dict[int, Dict] = {}
+        self.no_broadcast_optim_state = no_broadcast_optim_state
         self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
@@ -849,6 +858,12 @@ class FullyShardedDataParallel(nn.Module):
                 if m.process_group != self.process_group:
                     self.children_share_process_group = False
+                # If a child instance is in its own (smaller) world, that was probably an attempt to avoid OOM.
+                # Therefore, gathering this child's optim state will probably cause OOM, so we won't do it.
+                m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
+                    (m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
+                )

     def _setup_streams(self) -> None:
         """Create streams to overlap data transfer and computation."""
         if len(self._streams) > 0 or not self._is_root:
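
The auto-flagging condition above covers exactly the per-rank expert pattern exercised in the tests below. Restated as a standalone helper (a sketch; `parent` and `child` are hypothetical FSDP instances, attribute names as in the hunk):

def should_skip_broadcast(parent, child) -> bool:
    # A child living alone in a size-1 group smaller than the parent's world was
    # probably sharded that way to avoid OOM; gathering its state would likely OOM too.
    return (
        child.world_size == 1
        and child.world_size < parent.world_size
        and child.process_group != parent.process_group
    )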
@@ -1391,7 +1406,7 @@ class FullyShardedDataParallel(nn.Module):
         dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
         for rank in range(self.world_size):
             if rank == self.rank:
-                sd = optim.state_dict()
+                sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
                 sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
             else:
                 sd = dummy_tensor  # type: ignore
@@ -1428,8 +1443,11 @@ class FullyShardedDataParallel(nn.Module):
         if self.rank != recipient_rank and recipient_rank is not None:
             return None
         # Unify the shard states by concatenating tensors and unflattening params
-        new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states)
-        # TODO: check if this code supports nested instances with different world size
+        new_state_dict = ou.build_unflat_state_dict(
+            self._fsdp_instances, world_optim_states, self.uncollected_opt_state
+        )
+        self.uncollected_opt_state = {}
+        assert "uncollected_local_ids" in new_state_dict
         return new_state_dict

     @property
@@ -1437,6 +1455,17 @@ class FullyShardedDataParallel(nn.Module):
         """Returns all fsdp modules in self.modules() including self."""
         return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

+    def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
+        uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
+        new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
+        if self.rank == 0:
+            # Save placeholders for uncollected opt state to keep the same unflat OSD format.
+            self.uncollected_opt_state = {k: v for k, v in osd["state"].items() if k in uncollected_ids}
+
+        pg = copy.deepcopy(osd["param_groups"])
+        new_dct["param_groups"] = pg
+        return new_dct
+
     def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
         """Get the portion of the optimizer state dict associated with the shard
@@ -1451,18 +1480,19 @@ class FullyShardedDataParallel(nn.Module):
         """
         # Assert nesting is the same as it was at save time
         instance_list = self._fsdp_instances
-        assert all(
-            x.world_size == self.world_size for x in instance_list
-        ), "all nested instances must have same world size"
         ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
+        ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
         if self.flatten_parameters:
             full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
-            assert len(full_optim_state_dict["state"]) in (0, len(instance_list))
+            assert len(full_optim_state_dict["state"]) in (
+                0,
+                len(instance_list),
+            ), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}'

         # get the portion of dict associated with the shard, in place
         for id, s in full_optim_state_dict["state"].items():
             for k, v in s.items():
-                if torch.is_tensor(v):
+                if torch.is_tensor(v) and id not in ids_not_to_shard:
                     v_shard, _ = self._get_shard(v)
                 else:
                     v_shard = v  # dont shard entries that are not tensors
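
Taken together with the docstring above, the intended checkpoint round trip looks roughly like the sketch below (``fsdp_model`` and ``optim`` are placeholders; how each rank patches the uncollected expert entries is application specific):

import torch

# 1) Consolidate: rank 0 receives the full, unflattened optimizer state.
full_osd = fsdp_model.gather_full_optim_state_dict(optim, recipient_rank=0)  # placeholders
if full_osd is not None:  # non-recipient ranks get None
    # State of children flagged with no_broadcast_optim_state was not broadcast;
    # their local ids are listed here so the caller knows which entries to
    # overwrite with the proper per-rank (e.g. expert) state before saving.
    print(full_osd["uncollected_local_ids"])
    torch.save(full_osd, "full_optim_state.pt")

# 2) Restore: each rank loads the full dict and extracts only its own shard.
full_osd = torch.load("full_optim_state.pt")
shard_osd = fsdp_model.get_shard_from_optim_state_dict(full_osd)
optim.load_state_dict(shard_osd)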
@@ -782,6 +782,7 @@ class MixtureOfExperts(NestedWrappedModule):
         # "expert" params are different on each rank
         torch.manual_seed(42 + group.rank())
         expert = nn.Linear(16, 4)
+        self.num_expert_params = sum([p.numel() for p in expert.parameters()])
         for p in expert.parameters():
             p.expert = True
@@ -795,7 +796,7 @@ class MixtureOfExperts(NestedWrappedModule):
         if wrapper_config is not None:
             # we create a process group of size 1 for the expert params
-            expert_group = torch.distributed.new_group([group.rank()])
+            expert_group = torch.distributed.new_group([group.rank()])  # world size 1 means no shard
             expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)
             shared = FullyShardedDataParallel(shared, group, **wrapper_config)
@@ -16,7 +16,7 @@ from fairscale.utils.testing import objects_are_equal
 from .test_fsdp import (
     DistributedTest,
     DummyProcessGroup,
-    NestedWrappedModule,
+    MixtureOfExperts,
     TransformerWithSharedParams,
     rename_test,
     spawn_and_init,
@@ -36,11 +36,12 @@ def assert_equal(a, b):
 class TestOptimizerUtils(DistributedTest):
     @parameterized.expand(
-        [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True]],
+        [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True], [Adam, True]],
         name_func=rename_test,
     )
     def test_consolidate_optimizer(self, optim_fn, transformer):
         config = {"mixed_precision": True, "flatten_parameters": True}
+        config["compute_dtype"] = torch.float32
         test_fn = functools.partial(
             self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
         )
@@ -53,11 +54,11 @@ class TestOptimizerUtils(DistributedTest):
         # Establish reference behavior.
         if transformer:
-            unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
             fsdp = self.get_wrapped_model(group, config=config).cuda()
+            unwrapped_model = TransformerWithSharedParams(group).cuda()
         else:
-            fsdp = FullyShardedDataParallel(NestedWrappedModule(group, wrapper_config=config), group, **config).cuda()
-            unwrapped_model = NestedWrappedModule(group, wrapper_config=None).cuda()
+            unwrapped_model = MixtureOfExperts(group, wrapper_config=None).cuda()
+            fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()

         try:
             fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,)
@@ -68,19 +69,24 @@ class TestOptimizerUtils(DistributedTest):
         fsdp_optim.zero_grad()
         optim_unwrapped.zero_grad()
-        x = fsdp.module.get_input(torch.device("cuda"))
-        output = fsdp(*x)
-        loss = fsdp.module.get_loss(x, output).to("cuda")
-        fsdp.module.run_backward(loss)
-        fsdp_optim.step()
-
-        output = unwrapped_model(*x)
-        loss = unwrapped_model.get_loss(x, output)
-        unwrapped_model.run_backward(loss)
-        optim_unwrapped.step()
+        with torch.cuda.amp.autocast(enabled=True):
+            x = fsdp.module.get_input(torch.device("cuda"))
+            output = fsdp(*x)
+            loss = fsdp.module.get_loss(x, output).to("cuda")
+            fsdp.module.run_backward(loss)
+            fsdp_optim.step()
+
+            output = unwrapped_model(*x)
+            loss = unwrapped_model.get_loss(x, output)
+            unwrapped_model.run_backward(loss)
+            optim_unwrapped.step()
         unwrapped_sd = optim_unwrapped.state_dict()
+        if not transformer:
+            no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
+            assert len(no_broadcast_children) == 1
+            assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
         tstart = time()
         sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
         duration = time() - tstart
@@ -88,7 +94,14 @@ class TestOptimizerUtils(DistributedTest):
         assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

         if fsdp.rank > 0:
+            assert sd is None
             return

+        unflat_state = sd["state"]
+        assert "uncollected_local_ids" in sd
+        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
+        shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
+        state_after_get_shard = sd["state"]
+        assert objects_are_equal(unflat_state, state_after_get_shard)  # no side effects.
+
         assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
         assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
@@ -97,18 +110,21 @@ class TestOptimizerUtils(DistributedTest):
             sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]),
         )

-        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
         original_shard_sd = fsdp_optim.state_dict()
         assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
         assert_equal(shard_sd.keys(), original_shard_sd.keys())
         original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")

+        # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
+        assert_equal(
+            [first_tensor_numel(v) for k, v in shard_sd["state"].items()],
+            [first_tensor_numel(v) for k, v in original_shard_sd["state"].items()],
+        )
         assert_equal(
-            sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]),
-            sum([first_tensor_numel(v) for k, v in original_shard_sd["state"].items()]),
+            [v for k, v in shard_sd["param_groups"][0].items()],
+            [v for k, v in original_shard_sd["param_groups"][0].items()],
         )
-        assert objects_are_equal(shard_sd, original_shard_sd)
+        assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
+        assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)

     def test_named_params_ordering(self):
         """Test assumption of consolidate_optimizer_state_dict"""