Unverified Commit 9474d75d authored by Sam Shleifer, committed by GitHub

[FSDP][feature] optimizer state dict save and load (#537)

Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
parent df493a29
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
import copy
from typing import Dict, Generator, List, Tuple
import torch
# This function helps shard a full optimizer state dict
def flatten_optim_state_dict(sd: Dict) -> Dict:
"""Shard a full optimizer state dict (called by FSDP.get_shard_from_optim_state_dict)"""
param_id_map = sd["param_id_map"]
num_local_params = len(set(param_id_map.values()))
if sd["state"]:
new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
else:
new_state = {}
non_tensor_state = {}
# Populate `new_state["state"]`. (Assuming sd is sorted)
for expanded_pid, buffers in sd["state"].items():
consolidated_pid = param_id_map[expanded_pid]
for buffer_name, p in buffers.items():
if torch.is_tensor(p):
if buffer_name not in new_state[consolidated_pid]:
new_state[consolidated_pid][buffer_name] = []
new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
else:
non_tensor_state[buffer_name] = p
# Now combine all tensors in each buffer using torch.cat().
for consolidated_pid, state in new_state.items():
for buffer_name, tensors in state.items():
new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
new_state[consolidated_pid].update(non_tensor_state)
new_sd = {"state": new_state, "param_groups": sd["param_groups"]}
# add pointers from the `params` dict.
for pg_id, _ in enumerate(sd["param_groups"]):
# TODO: this list could be huge. Can we avoid materializing?
new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))
return new_sd
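
# Illustrative sketch (not part of the original file; values are made up): suppose two
# unflattened parameters, global ids 0 and 1, both map to the single flattened parameter
# (local id 0) of one FSDP instance, with Adam-like state:
#
#   full_sd = {
#       "param_id_map": {0: 0, 1: 0},
#       "state": {
#           0: {"exp_avg": torch.ones(4), "step": 10},
#           1: {"exp_avg": torch.zeros(2), "step": 10},
#       },
#       "param_groups": [{"lr": 0.001, "params": [0, 1]}],
#   }
#
# flatten_optim_state_dict(full_sd) keeps only local id 0: the two "exp_avg" buffers are
# flattened and concatenated into one 6-element tensor, the non-tensor "step" entry is
# copied through, and param_groups[0]["params"] becomes [0].
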
def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None:
    n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values()))
    msg = (
        f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved "
        f"there were {n_local_params_in_opt}"
    )
    stateless = len(full_optim_state_dict["state"]) == 0
    assert stateless or (n_instances == n_local_params_in_opt), msg


# All functions below here help saving the list of optimizer states, one from each rank
# build_unflat_state_dict is the interface used by FSDP
def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict:
    non_tensor_state = {}  # This state is like the `step` count in Adam, not a tensor, so we don't unpad or cat it.
    for k, v in combined_state[param_id].items():
        if torch.is_tensor(v[0]):
            continue
        elif len(set(v)) == 1:
            non_tensor_state[k] = v[0]
        else:
            raise TypeError(f"Don't know how to consolidate optimizer param {k} with values {v}")
    return non_tensor_state


def _combine_state(states: List[Dict]) -> Dict[int, Dict]:
    combined_state = states[0]
    for param_id in combined_state:
        combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
    if len(states) == 1:
        return combined_state

    for rank, s in enumerate(states[1:]):
        for param_id, param_state in s.items():
            for k, tensor in param_state.items():
                combined_state[param_id][k].append(tensor)
    return combined_state
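
# Illustrative sketch (made-up values): _combine_state turns a list of per-rank state
# dicts into a single dict whose leaves are lists indexed by rank, e.g. for world_size=2:
#
#   _combine_state([
#       {0: {"exp_avg": shard_rank0, "step": 10}},
#       {0: {"exp_avg": shard_rank1, "step": 10}},
#   ]) == {0: {"exp_avg": [shard_rank0, shard_rank1], "step": [10, 10]}}
#
# _extract_non_tensor_state then checks that the non-tensor entries ("step" here) agree
# across ranks and keeps a single copy.
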
def _unflatten_optim_state(
    combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
    # local ids are the keys in the current state (combined_state), (usually fewer)
    # global ids will be the keys in the unflattened state
    next_global_id = 0  # gets incremented
    pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state}
    local_ids = [id for id in sorted(combined_state.keys())]

    # non_tensor_state refers to entries in sd[state][param_id] that are not tensors, like "step".
    # we check that these are identical across workers and then take the first
    non_tensor_state = [_extract_non_tensor_state(combined_state, id) for id in combined_state]

    # local corresponds to flattened, global corresponds to unflattened
    num_unflat_params = [len(m._param_numels) for m in instance_list]  # type: ignore
    global_to_local_id = {}
    for local_id, num_unflat in enumerate(num_unflat_params):
        for _ in range(num_unflat):
            global_to_local_id[next_global_id] = local_id
            next_global_id += 1
    if not combined_state:
        return {}, global_to_local_id

    # If the constant state is the same as the combined state, copy it N times, no unflattening needed.
    unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))}
    if non_tensor_state[0].keys() == combined_state[0].keys():
        return unflat_state, global_to_local_id

    local_to_global: Dict[int, List] = {i: [] for i in local_ids}
    for g, l in global_to_local_id.items():
        local_to_global[l].append(g)
    # loop over parameters in state.
    # Tensor state will be padded, concatenated, and restored to original shape with FlattenParamsWrapper.get_views
    # get_views returns multiple tensors, each of which is a new parameter with a new "global" id.
    for local_id in local_ids:
        # undo the work of shard_parameters
        for k, v in combined_state[local_id].items():
            if k in non_tensor_state[local_id]:
                continue
            assert isinstance(v, list), f"got {k}: {v} for {local_id}"
            v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
            flat_buffer = torch.cat(v_unpad)
            param_views: Generator = instance_list[local_id].get_param_views(flat_buffer)  # type: ignore
            for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
                assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
                unflat_state[global_id][k] = param_view

    return unflat_state, global_to_local_id
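
# Illustrative sketch (made-up sizes): for a flat parameter with 11 elements sharded over
# world_size=2, torch.chunk gives pieces of 6 and 5 elements and the second shard is
# padded by 1, so pad_info[local_id] == [0, 1]. For a tensor buffer such as "exp_avg",
# the unpadding above trims the per-rank pieces back to 6 and 5 elements, torch.cat
# restores the 11-element flat buffer, and get_param_views reshapes it into the original
# unflattened parameters, one per new "global" id.
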
def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
    """Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
    world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
    assert all(len(s) == len(instance_list) for s in world_pad_info)
    assert all(len(s[0]) == 1 for s in world_pad_info)
    param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
    assert len(param_groups) == 1

    # Aggregate from a list of dictionaries to a dictionary of lists
    combined_state = _combine_state([x["state"] for x in world_optim_states])
    del world_optim_states

    # local ids are in the current state, global_ids will be in returned state.
    unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)

    num_params = sum([len(m._param_numels) for m in instance_list])  # type: ignore
    param_groups[0]["params"] = list(range(num_params))  # This could be a large list. #TODO: is it essential
    return {
        "state": dict(sorted(unflat_state.items())),  # NOTE: this is probably already sorted
        "param_id_map": global_to_local_id,
        "param_groups": param_groups,
    }
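
# Illustrative sketch (made-up shapes): each entry of `world_optim_states` is one rank's
# optim.state_dict() plus the "num_padded" key added by FSDP._consolidate_optim_state_dict,
# e.g. for a model wrapped as a single FSDP instance with flatten_parameters=True:
#
#   {
#       "state": {0: {"exp_avg": <flat shard>, "exp_avg_sq": <flat shard>, "step": 10}},
#       "param_groups": [{"lr": 0.001, "params": [0], ...}],
#       "num_padded": [[3]],  # one list per FSDP instance, one entry per flat param
#   }
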
@@ -21,12 +21,14 @@ import torch.nn.functional as F
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.optim.utils import calc_grad_norm
from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_
from . import fsdp_optim_utils as ou
if TYPE_CHECKING:
    from collections import OrderedDict  # noqa: F401
@@ -88,8 +90,8 @@ class FullyShardedDataParallel(nn.Module):
        import torch
        from fairscale.nn.auto_wrap import enable_wrap, auto_wrap
        from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
        fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
        with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
        fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
        with enable_wrap(**fsdp_params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))
            assert isinstance(self.l1, FSDP)
@@ -185,6 +187,8 @@ class FullyShardedDataParallel(nn.Module):
        self.buffer_dtype = buffer_dtype or self.compute_dtype
        self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
        self.bucket_cap_mb = bucket_cap_mb
        self.numel_padded_per_param: List[int] = []
        self.compute_device = compute_device
        if self.fp32_reduce_scatter and not self.mixed_precision:
@@ -412,6 +416,7 @@ class FullyShardedDataParallel(nn.Module):
        allocate less memory for optimizer state, avoiding redundancy across
        data parallel workers.
        """
        self.numel_padded_per_param = []
        for p in self.params:
            assert not hasattr(p, "_is_sharded")
            assert p.is_floating_point()
@@ -423,16 +428,19 @@ class FullyShardedDataParallel(nn.Module):
            p._orig_size = p.data.size()
            if not p._is_sharded:
                self.numel_padded_per_param.append(0)
                continue
            p._is_sharded = True
            # Replace p.data with the relevant shard.
            orig_data = p.data
            p.data = self._get_shard(p.data)
            p.data, num_padded = self._get_shard(p.data)
            self.numel_padded_per_param.append(num_padded)
            free_storage_(orig_data)
        assert len(self.numel_padded_per_param) == len(self.params)

    def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the local shard of a given full tensor."""
    def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """Return the local shard of a full tensor."""
        # Shard using torch.chunk to match all-gather/reduce-scatter.
        chunks = list(torch.flatten(tensor).chunk(self.world_size))
        while len(chunks) < self.world_size:
@@ -445,7 +453,7 @@ class FullyShardedDataParallel(nn.Module):
        shard = chunks[self.rank].clone()
        if num_to_pad > 0:
            shard = F.pad(shard, [0, num_to_pad])
        return shard
        return shard, num_to_pad
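
    # Illustrative note (not part of the diff): with world_size=4 and a flat tensor of 10
    # elements, torch.chunk yields pieces of 3, 3, 3 and 1 elements, so ranks 0-2 get
    # num_to_pad=0 and rank 3 gets num_to_pad=2; every rank then holds a 3-element shard.
    # The returned num_to_pad is what gets recorded in numel_padded_per_param and later
    # stripped off by fsdp_optim_utils when rebuilding the full optimizer state.
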
    def extra_repr(self) -> str:
        return (
@@ -684,7 +692,7 @@ class FullyShardedDataParallel(nn.Module):
            if not volatile:
                # Copy any changes made to the full params back into
                # the corresponding local shards.
                local_shard = self._get_shard(full_tensor)
                local_shard, _ = self._get_shard(full_tensor)
                p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
            if safe_to_free:
                free_storage_(full_tensor)
@@ -1346,6 +1354,111 @@ class FullyShardedDataParallel(nn.Module):
            traceback.print_stack()
            raise ValueError(msg)
    def _consolidate_optim_state_dict(
        self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
    ) -> List[Dict]:
        """Update the consolidated state_dict list, one per rank.

        Args:
            optim (Optimizer): an optimizer instance for this FSDP rank. Its state is
                used in the consolidation. However, its state is not modified.
            recipient_rank (int): on which rank to materialize the full state dict.
                None is a special value, which means that all ranks should have the state.

        Returns:
            all_states (list[dict]): the optimizer state from each rank

        .. warning: This needs to be called on all replicas"""
        self._lazy_init()
        # NOTE(SS): we do not support param groups yet, as they seem to break FSDP
        # Pull the sharded state from all the other replicas
        # Store all the states in order, rank by rank
        should_collect_state = recipient_rank is None or (self.rank == recipient_rank)
        all_states: List[Dict[str, Any]] = []
        dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
        for rank in range(self.world_size):
            if rank == self.rank:
                sd = optim.state_dict()
                sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
            else:
                sd = dummy_tensor  # type: ignore
            sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)  # type: ignore
            if should_collect_state:
                assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
                all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
        return all_states
    def gather_full_optim_state_dict(
        self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = 0
    ) -> Optional[Dict[str, Any]]:
        """Return the last known global optimizer state. The returned state is compatible with PyTorch, in that the
        sharded properties are not exposed. Multiple parameter groups are not yet supported.

        This should be called only on the root FSDP instance.
        Nested FSDP instances with different world_size process groups are not supported.

        Args:
            optim (Optimizer): an optimizer instance for this FSDP rank. Its state is
                used in the consolidation. However, its state is not modified.
            recipient_rank (int): on which rank to materialize the full state dict.

        Returns:
            a dict with two entries
                * state - a dict holding gathered optimization state, 1 entry per unflat parameter
                * param_groups - a dict containing the 1 parameter group
        """
        if not self.flatten_parameters:
            raise NotImplementedError("optim state dict requires flatten_parameters=True")
        world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank)
        if self.rank != recipient_rank and recipient_rank is not None:
            return None
        # Unify the shard states by concatenating tensors and unflattening params
        new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states)
        # TODO: check if this code supports nested instances with different world size
        return new_state_dict
    @property
    def _fsdp_instances(self) -> List[nn.Module]:
        """Returns all fsdp modules in self.modules() including self."""
        return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

    def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Get the portion of the optimizer state dict associated with the shard.

        This can be used to get the right sharded optimizer state to be loaded
        into the sharded optimizer for this FSDP rank.

        Args:
            full_optim_state_dict (dict): consolidated optimizer state returned by ``gather_full_optim_state_dict``,
                or loaded from a checkpoint.

        Returns:
            (dict): a shard of the optimizer state.
        """
        # Assert nesting is the same as it was at save time
        instance_list = self._fsdp_instances
        assert all(
            x.world_size == self.world_size for x in instance_list
        ), "all nested instances must have same world size"
        ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
        if self.flatten_parameters:
            full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
            assert len(full_optim_state_dict["state"]) in (0, len(instance_list))

        # get the portion of dict associated with the shard, in place
        for id, s in full_optim_state_dict["state"].items():
            for k, v in s.items():
                if torch.is_tensor(v):
                    v_shard, _ = self._get_shard(v)
                else:
                    v_shard = v  # don't shard entries that are not tensors
                full_optim_state_dict["state"][id][k] = v_shard
        return full_optim_state_dict
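
# Illustrative usage sketch (not part of the diff; `fsdp_model`, `optim` and `PATH` are
# placeholders): saving a consolidated optimizer checkpoint on rank 0 and later restoring
# a per-rank shard. gather_full_optim_state_dict must be called on every rank, and at load
# time every rank needs the full dict (e.g. by loading the same file) before extracting
# its own shard.
#
#   full_osd = fsdp_model.gather_full_optim_state_dict(optim, recipient_rank=0)
#   if fsdp_model.rank == 0:
#       torch.save(full_osd, PATH)
#   ...
#   full_osd = torch.load(PATH)
#   shard_osd = fsdp_model.get_shard_from_optim_state_dict(full_osd)
#   optim.load_state_dict(shard_osd)
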
@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
......
@@ -122,7 +122,7 @@ class FlattenParamsWrapper(nn.Module):
        # register the views as plain attributes
        self._unflatten_params_as_views()

    def _get_param_views(self, flat_param: Tensor) -> Generator:
    def get_param_views(self, flat_param: Tensor) -> Generator:
        return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))

    def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:
@@ -130,7 +130,7 @@ class FlattenParamsWrapper(nn.Module):
        self.is_flattened = False
        flat_param = flat_param if flat_param is not None else self.flat_param
        ps = self._get_param_views(flat_param)
        ps = self.get_param_views(flat_param)
        for (m, n), p in zip(self._param_infos, ps):
            if hasattr(m, n):
                delattr(m, n)
@@ -144,7 +144,7 @@ class FlattenParamsWrapper(nn.Module):
    def _unflatten_params_as_views(self) -> None:
        assert self.is_flattened
        ps = self._get_param_views(self.flat_param)
        ps = self.get_param_views(self.flat_param)
        for (m, n), p in zip(self._param_infos, ps):
            setattr(m, n, p)  # This will set as plain attr
        for (m, n, shared_m, shared_n) in self._shared_param_infos:
......
@@ -5,6 +5,7 @@ tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/pipe/skip/test_gpipe.py
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
from time import time

from parameterized import parameterized
import torch
from torch.optim import SGD, Adadelta, Adam  # type: ignore

from fairscale.nn import FullyShardedDataParallel
from fairscale.optim.utils import recursive_copy_to_device
from fairscale.utils.testing import objects_are_equal

from .test_fsdp import (
    DistributedTest,
    DummyProcessGroup,
    NestedWrappedModule,
    TransformerWithSharedParams,
    rename_test,
    spawn_and_init,
)


def first_tensor_numel(dct):
    for k, v in dct.items():
        if torch.is_tensor(v):
            return v.numel()
    return 0


def assert_equal(a, b):
    assert a == b, f"{a} != {b}"


class TestOptimizerUtils(DistributedTest):
    @parameterized.expand(
        [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True]],
        name_func=rename_test,
    )
    def test_consolidate_optimizer(self, optim_fn, transformer):
        config = {"mixed_precision": True, "flatten_parameters": True}
        test_fn = functools.partial(
            self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
        )
        spawn_and_init(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)])

    @classmethod
    def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False):
        """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
        # Establish reference behavior.
        if transformer:
            fsdp = self.get_wrapped_model(group, config=config).cuda()
            unwrapped_model = TransformerWithSharedParams(group).cuda()
        else:
            fsdp = FullyShardedDataParallel(NestedWrappedModule(group, wrapper_config=config), group, **config).cuda()
            unwrapped_model = NestedWrappedModule(group, wrapper_config=None).cuda()

        try:
            fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,)
            optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
        except TypeError:  # Adadelta
            fsdp_optim = optim_fn(fsdp.parameters())
            optim_unwrapped = optim_fn(unwrapped_model.parameters())

        fsdp_optim.zero_grad()
        optim_unwrapped.zero_grad()

        x = fsdp.module.get_input(torch.device("cuda"))
        output = fsdp(*x)
        loss = fsdp.module.get_loss(x, output).to("cuda")
        fsdp.module.run_backward(loss)
        fsdp_optim.step()

        output = unwrapped_model(*x)
        loss = unwrapped_model.get_loss(x, output)
        unwrapped_model.run_backward(loss)
        optim_unwrapped.step()
        unwrapped_sd = optim_unwrapped.state_dict()

        tstart = time()
        sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
        duration = time() - tstart
        # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
        assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

        if fsdp.rank > 0:
            return

        assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
        assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
        assert_equal(
            sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
            sum([first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()]),
        )

        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)

        original_shard_sd = fsdp_optim.state_dict()
        assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
        assert_equal(shard_sd.keys(), original_shard_sd.keys())
        original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")

        assert_equal(
            sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]),
            sum([first_tensor_numel(v) for k, v in original_shard_sd["state"].items()]),
        )
        assert objects_are_equal(shard_sd, original_shard_sd)

    def test_named_params_ordering(self):
        """Test assumption of consolidate_optimizer_state_dict"""
        group = DummyProcessGroup(0, 1)
        model = TransformerWithSharedParams(group)
        named_pars = [p for n, p in model.named_parameters()]
        for i, p in enumerate(model.parameters()):
            assert objects_are_equal(p, named_pars[i])