"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "f44f20cffc45ca226d5bf7e25c11fb0d1119b4bb"
Unverified commit a6549be7 authored by Sam Shleifer, committed by GitHub

[fix] [FSDP] optim state dict should be completely on CPU (#590)

parent ce1f2cea
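The gist of the fix: when rank 0 consolidates the optimizer state dict, every tensor in it, including the placeholders kept for uncollected (no_broadcast_optim_state) instances, is copied to CPU instead of lingering on GPU. Below is a minimal, self-contained sketch of the idea; copy_state_to_cpu is an illustrative stand-in for the recursive_copy_to_device(..., device=torch.device("cpu")) call used in the diff, not the actual fairscale helper:

import torch


def copy_state_to_cpu(value):
    # Illustrative stand-in for recursive_copy_to_device(value, non_blocking=False,
    # device=torch.device("cpu")): walk nested containers and move every tensor to CPU.
    if torch.is_tensor(value):
        return value.to(torch.device("cpu"), non_blocking=False)
    if isinstance(value, dict):
        return {k: copy_state_to_cpu(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(copy_state_to_cpu(v) for v in value)
    return value


# Illustrative per-parameter optimizer state whose buffers may live on GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
state = {0: {"step": 10, "exp_avg": torch.randn(4, device=device)}}

cpu_state = copy_state_to_cpu(state)
assert all(
    t.device == torch.device("cpu")
    for buffers in cpu_state.values()
    for t in buffers.values()
    if torch.is_tensor(t)
)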
@@ -8,6 +8,7 @@ import copy
 from enum import Enum, auto
 import functools
 from math import inf
+import time
 import traceback
 from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union
@@ -208,6 +209,7 @@ class FullyShardedDataParallel(nn.Module):
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
         self.numel_padded_per_param: List[int] = []
+        self._tstart = time.time()
         if self.fp32_reduce_scatter and not self.mixed_precision:
             raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
@@ -1414,7 +1416,6 @@ class FullyShardedDataParallel(nn.Module):
             if should_collect_state:
                 assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
                 all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
-
         return all_states

     def gather_full_optim_state_dict(
@@ -1459,8 +1460,12 @@ class FullyShardedDataParallel(nn.Module):
         uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
         new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
         if self.rank == 0:
-            # Save placeholders for uncollected opt state to keep the same unflat OSD format.
-            self.uncollected_opt_state = {k: v for k, v in osd["state"].items() if k in uncollected_ids}
+            # Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
+            self.uncollected_opt_state = {
+                k: recursive_copy_to_device(v, non_blocking=False, device=torch.device("cpu"))
+                for k, v in osd["state"].items()
+                if k in uncollected_ids
+            }

         pg = copy.deepcopy(osd["param_groups"])
         new_dct["param_groups"] = pg
@@ -1500,6 +1505,14 @@ class FullyShardedDataParallel(nn.Module):
         return full_optim_state_dict

+    def _print_r0(self, msg: str) -> None:
+        """Debugging utility to print memory usage stats nicely on rank 0"""
+        if self.rank == 0:
+            gb_denom = 1024 ** 3
+            print(
+                f"{msg} cur={torch.cuda.memory_allocated()/gb_denom: .4f} GB, max={torch.cuda.max_memory_allocated()/gb_denom: .4f} GB, t={time.time()-self._tstart: .1f}"
+            )
+

 def _get_default_cuda_device(module: nn.Module) -> torch.device:
     """Try to infer CUDA device from module parameters."""
...
@@ -627,14 +627,15 @@ class MixtureOfExperts(NestedWrappedModule):
         # "expert" params are different on each rank
         torch.manual_seed(42 + group.rank())
-        expert = nn.Linear(16, 4)
+        d_expert = 16
+        expert = nn.Linear(d_expert, 4)
         self.num_expert_params = sum([p.numel() for p in expert.parameters()])
         for p in expert.parameters():
             p.expert = True

         # everything else is shared
         torch.manual_seed(0)
-        shared = nn.Linear(4, 16)
+        shared = nn.Linear(4, d_expert)
         if checkpoint_act:
             expert = checkpoint_wrapper(expert)
...
@@ -86,16 +86,30 @@ class TestOptimizerUtils(DistributedTest):
         no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
         assert len(no_broadcast_children) == 1
         assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
+        torch.cuda.empty_cache()
+        cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
         tstart = time()
         sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
         duration = time() - tstart
         # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
         assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"
+        cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
+        mem_usg_gb = cuda_gb_after - cuda_gb_before
+        assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
+        assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

         if fsdp.rank > 0:
             assert sd is None
             return

+        # assert whole state dict on CPU
+        for k, v in sd["state"].items():
+            for buffer_name, t in v.items():
+                if torch.is_tensor(t):
+                    msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
+                    assert t.device == torch.device("cpu"), msg

         unflat_state = sd["state"]
         assert "uncollected_local_ids" in sd

         shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
...
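For reference, a hedged sketch of how the consolidated-save path exercised by this test is typically driven from user code; the module, optimizer, and file name are illustrative and assume an already-initialized torch.distributed process group:

import torch
from fairscale.nn import FullyShardedDataParallel as FSDP

# Illustrative setup: any module wrapped in FSDP.
model = FSDP(torch.nn.Linear(16, 4).cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# ... run some training steps so the optimizer has state ...

# Rank 0 receives the full, unflattened optimizer state dict; other ranks get None.
full_osd = model.gather_full_optim_state_dict(optim, recipient_rank=0)
if full_osd is not None:
    # After this fix, every tensor in full_osd lives on CPU, so saving it does not
    # hold an extra GPU copy of the optimizer state.
    torch.save(full_osd, "optim_state.pt")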