Unverified Commit a6549be7 authored by Sam Shleifer, committed by GitHub

[fix] [FSDP] optim state dict should be completely on CPU (#590)

parent ce1f2cea
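
The core of the fix is to move every tensor in the gathered optimizer state onto the CPU before it is cached or returned, so rank 0 does not keep CUDA copies of the consolidated state. Below is a minimal sketch of the recursive copy the diff relies on; fairscale's own recursive_copy_to_device may differ in details, and copy_state_to_device here is a hypothetical stand-in:

from typing import Any

import torch

def copy_state_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
    # Tensors are copied to the target device; containers are rebuilt
    # element by element; scalar leaves (e.g. Adam's "step") pass through.
    if torch.is_tensor(value):
        return value.to(device, non_blocking=non_blocking)
    if isinstance(value, dict):
        return {k: copy_state_to_device(v, non_blocking, device) for k, v in value.items()}
    if isinstance(value, list):
        return [copy_state_to_device(v, non_blocking, device) for v in value]
    return value

# An Adam-style state entry with GPU buffers comes back fully on CPU.
entry = {"step": 10, "exp_avg": torch.zeros(4), "exp_avg_sq": torch.zeros(4)}
cpu_entry = copy_state_to_device(entry, non_blocking=False, device=torch.device("cpu"))
assert all(t.device.type == "cpu" for t in cpu_entry.values() if torch.is_tensor(t))
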
@@ -8,6 +8,7 @@ import copy
from enum import Enum, auto
import functools
from math import inf
import time
import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union
@@ -208,6 +209,7 @@ class FullyShardedDataParallel(nn.Module):
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self.numel_padded_per_param: List[int] = []
self._tstart = time.time()
if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
@@ -1414,7 +1416,6 @@ class FullyShardedDataParallel(nn.Module):
if should_collect_state:
assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
return all_states
def gather_full_optim_state_dict(
@@ -1459,8 +1460,12 @@ class FullyShardedDataParallel(nn.Module):
uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
if self.rank == 0:
# Save placeholders for uncollected opt state to keep the same unflat OSD format.
self.uncollected_opt_state = {k: v for k, v in osd["state"].items() if k in uncollected_ids}
# Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
self.uncollected_opt_state = {
k: recursive_copy_to_device(v, non_blocking=False, device=torch.device("cpu"))
for k, v in osd["state"].items()
if k in uncollected_ids
}
pg = copy.deepcopy(osd["param_groups"])
new_dct["param_groups"] = pg
@@ -1500,6 +1505,14 @@ class FullyShardedDataParallel(nn.Module):
return full_optim_state_dict
def _print_r0(self, msg: str) -> None:
"""Debugging utility to print memory usage stats nicely on rank 0"""
if self.rank == 0:
gb_denom = 1024 ** 3
print(
f"{msg} cur={torch.cuda.memory_allocated()/gb_denom: .4f} GB, max={torch.cuda.max_memory_allocated()/gb_denom: .4f} GB, t={time.time()-self._tstart: .1f}"
)
def _get_default_cuda_device(module: nn.Module) -> torch.device:
"""Try to infer CUDA device from module parameters."""
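For reference, the new _print_r0 helper is meant to be placed around expensive steps to trace CUDA memory and elapsed time on rank 0. The call sites below are purely illustrative and not part of this commit:

# Hypothetical placement inside gather_full_optim_state_dict (illustrative only):
self._print_r0("before consolidating optim state")
self._consolidate_optim_state_dict(optim, recipient_rank)
self._print_r0("after consolidating optim state")
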
@@ -627,14 +627,15 @@ class MixtureOfExperts(NestedWrappedModule):
# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
expert = nn.Linear(16, 4)
d_expert = 16
expert = nn.Linear(d_expert, 4)
self.num_expert_params = sum([p.numel() for p in expert.parameters()])
for p in expert.parameters():
p.expert = True
# everything else is shared
torch.manual_seed(0)
shared = nn.Linear(4, 16)
shared = nn.Linear(4, d_expert)
if checkpoint_act:
expert = checkpoint_wrapper(expert)
@@ -86,16 +86,30 @@ class TestOptimizerUtils(DistributedTest):
no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
assert len(no_broadcast_children) == 1
assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
torch.cuda.empty_cache()
cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
tstart = time()
sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
duration = time() - tstart
# Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"
cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
mem_usg_gb = cuda_gb_after - cuda_gb_before
assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"
if fsdp.rank > 0:
assert sd is None
return
# assert whole state dict on CPU
for k, v in sd["state"].items():
for buffer_name, t in v.items():
if torch.is_tensor(t):
msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
assert t.device == torch.device("cpu"), msg
unflat_state = sd["state"]
assert "uncollected_local_ids" in sd
shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
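
The assertion loop in the test walks exactly two levels of sd["state"]; a recursive variant covers arbitrarily nested state as well. A sketch, assuming only that the state dict nests tensors inside plain dicts and lists:

from typing import Any

import torch

def assert_all_on_cpu(value: Any, path: str = "sd") -> None:
    # Fail loudly, with the offending path, if any tensor is still on a CUDA device.
    if torch.is_tensor(value):
        assert value.device == torch.device("cpu"), f"{path} is on {value.device}, expected CPU"
    elif isinstance(value, dict):
        for k, v in value.items():
            assert_all_on_cpu(v, f"{path}[{k!r}]")
    elif isinstance(value, (list, tuple)):
        for i, v in enumerate(value):
            assert_all_on_cpu(v, f"{path}[{i}]")

# assert_all_on_cpu(sd["state"]) would replace the two nested loops in the test.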