Unverified Commit 0db50ce5 authored by Min Xu, committed by GitHub

[fix] [MEVO]: make mevo work with eval and optim_state checkpointing (#851)



* [fix]: fix eval for shared weight FSDP

* fixing optim state saving

* add changelog

* reformat with newer local isort

* update test

* avoid computing reference state unless we are testing training

* added optim_state test

* make mypy happy

* move tests; maybe we need to run the CUDA-memory-related tests first in the lists
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent fd831c4a
......@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
and the file path for storing params on SSD. Note: This is an experimental feature. [#855]
### Changed
- MEVO: fixed eval and checkpointing code paths [#851]
- Cleanup: Moving forward we would be testing all of our code with Python 3.9.7, CUDA 11.2 and the following three versions of PyTorch [#847]:
- the most recent stable version
- the most recent LTS version
......
......@@ -10,7 +10,6 @@ import tempfile
import torch
from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive
......
......@@ -18,6 +18,8 @@ import torch.multiprocessing as mp
import torch.nn as nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer
from fairscale.experimental.nn.ampnet_pipe import pipe
from fairscale.nn.model_parallel import initialize_model_parallel
......@@ -25,8 +27,6 @@ from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule
from fairscale.optim import GradScaler
from fairscale.utils.testing import dist_init, get_worker_map
import torchtext
from torchtext.data.utils import get_tokenizer
try:
from fairscale.optim import Adam # type: ignore
......
......@@ -378,7 +378,7 @@ class BackwardTrigger(nn.Module):
def __init__(self, linked_param: torch.Tensor):
super().__init__()
assert isinstance(linked_param, nn.Parameter)
self.trigger = nn.Parameter(torch.rand(1, dtype=linked_param.dtype))
self.trigger = nn.Parameter(torch.rand(1, dtype=linked_param.dtype, device=linked_param.device))
self.trigger._linked_param = linked_param
def forward(self) -> torch.Tensor: # type: ignore
......@@ -437,7 +437,8 @@ class MemoryEfficientVocabOutput(nn.Module): # AKA. MEVO
print("DEBUG cur, peak", cur_mem, mem)
assert isinstance(input, torch.Tensor)
assert isinstance(target, torch.Tensor)
assert input.requires_grad
if torch.is_grad_enabled():
assert input.requires_grad
input, target = _reshape_inputs(input, target)
tokens, d_model = input.shape
......
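For context on the two MEVO fixes in the hunks above, here is a hedged, standalone sketch (illustrative only; TriggerSketch is a stand-in, not fairscale's actual BackwardTrigger or MEVO classes): helper parameters are created on the linked parameter's device, and the requires_grad assertion is guarded by torch.is_grad_enabled() so the forward pass also works under torch.no_grad() during eval.

```python
# Hedged sketch of the two patterns applied above; not fairscale's code.
import torch
import torch.nn as nn


class TriggerSketch(nn.Module):
    def __init__(self, linked_param: nn.Parameter):
        super().__init__()
        # Allocate the trigger with the linked parameter's dtype *and* device,
        # so a CUDA-resident linked_param does not end up with a CPU trigger.
        self.trigger = nn.Parameter(
            torch.rand(1, dtype=linked_param.dtype, device=linked_param.device)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only insist on a grad-tracking input when autograd is recording;
        # under torch.no_grad() the input legitimately has requires_grad=False.
        if torch.is_grad_enabled():
            assert x.requires_grad
        return x * self.trigger


if __name__ == "__main__":
    linked = nn.Parameter(torch.randn(4))
    m = TriggerSketch(linked)
    m(torch.randn(4, requires_grad=True))  # training-style call
    with torch.no_grad():
        m(torch.randn(4))                  # eval-style call, assertion skipped
```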
......@@ -4,12 +4,15 @@
# LICENSE file in the root directory of this source tree.
"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
import copy
from typing import Any, Dict, Iterator, List, Tuple, cast
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, cast
import torch
from fairscale.nn.misc import FlattenParamsWrapper
if TYPE_CHECKING:
from fairscale.nn.data_parallel import FullyShardedDataParallel
# These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}
......@@ -84,10 +87,11 @@ def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_
def _unflatten_optim_state(
combined_state: Dict[int, Dict],
instance_list: List[torch.nn.Module],
instance_list: List["FullyShardedDataParallel"],
world_pad_info: List[List[List[int]]],
singleton_state: Dict[int, Dict],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
"""Convert optimizer state for flattened parameters into original, unflatten ones."""
# local ids are the keys in the current state (combined_state), (usually fewer)
# global ids will be the keys in the unflattened state
next_global_id = 0 # gets incremented
......@@ -100,7 +104,13 @@ def _unflatten_optim_state(
# Local corresponds to flattened, global corresponds to unflattened.
# Casting needed only for mypy.
num_global_params = [cast(int, m.num_params_managed) for m in instance_list]
num_global_params: List[int] = []
for m in instance_list:
if m.flatten_parameters:
num_flatten = cast(int, m.num_params_managed)
num_global_params.append(num_flatten)
else:
num_global_params.append(len(m.non_shared_params()))
global_to_local_id = {}
for local_id, num_unflat in enumerate(num_global_params):
for _ in range(num_unflat):
......@@ -129,18 +139,26 @@ def _unflatten_optim_state(
assert isinstance(v, list), f"got {k}: {v} for {local_id}"
v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
flat_buffer = torch.cat(v_unpad)
# Casting needed only for mypy.
param_views: Iterator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views([flat_buffer])
for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
unflat_state[global_id][k] = param_view
unflat_state[global_id].update(singleton_state[local_id])
if instance_list[local_id].flatten_parameters:
# Unflatten. Casting needed only for mypy.
param_views: Iterator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views(
[flat_buffer]
)
for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
unflat_state[global_id][k] = param_view
else:
# Copy non-flatten state directly.
assert len(local_to_global[local_id]) == 1, "Only support a single non-flatten parameter"
global_id = local_to_global[local_id][0]
unflat_state[global_id][k] = flat_buffer
unflat_state[global_id].update(singleton_state[local_id])
return unflat_state, global_to_local_id
def build_unflat_state_dict(
instance_list: List[torch.nn.Module],
instance_list: List["FullyShardedDataParallel"],
world_pad_info: List[List[List[int]]],
state: Dict[int, Dict[str, List[torch.Tensor]]],
singleton_state: Dict[int, Dict[str, List[torch.Tensor]]],
......
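The branch added above distinguishes flattened from non-flattened FSDP instances when rebuilding per-parameter optimizer state. Below is a minimal sketch of the flattened case only, with illustrative names (unflatten_state_buffer is not fairscale's API): a concatenated state buffer is split back into one view per original parameter shape, while a non-flattened instance holding a single non-shared parameter keeps the buffer as that parameter's state directly.

```python
# Minimal sketch, assuming the flat buffer is the concatenation of the
# per-parameter state tensors in parameter order (illustrative names only).
from typing import List

import torch


def unflatten_state_buffer(flat_buffer: torch.Tensor, shapes: List[torch.Size]) -> List[torch.Tensor]:
    views = []
    offset = 0
    for shape in shapes:
        numel = shape.numel()
        views.append(flat_buffer[offset : offset + numel].view(shape))
        offset += numel
    assert offset == flat_buffer.numel(), "buffer must match the parameter shapes exactly"
    return views


shapes = [torch.Size([3, 4]), torch.Size([5])]
flat = torch.arange(17, dtype=torch.float32)  # 12 + 5 elements
per_param_state = unflatten_state_buffer(flat, shapes)
assert [v.shape for v in per_param_state] == shapes
# A non-flattened FSDP instance with exactly one non-shared parameter skips the
# split and stores the whole buffer as that parameter's state.
```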
......@@ -353,6 +353,7 @@ class FullyShardedDataParallel(nn.Module):
params.append(param)
self._has_params = len(params) > 0
self._has_shared_params = False
# TODO(anj): Should we conditionally do this only if we have params?
# TODO(anj): Figure out if we can allocate the buffer during sharding.
......@@ -492,6 +493,14 @@ class FullyShardedDataParallel(nn.Module):
len(list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))) > 0
), "Must have at least 1 non-shared param."
self.params.append(p)
self._has_shared_params = True
def non_shared_params(self) -> List[nn.Parameter]:
"""Return the list of non-shared parameters."""
if self._has_shared_params:
return list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))
else:
return self.params
def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
"""
......@@ -1050,9 +1059,7 @@ class FullyShardedDataParallel(nn.Module):
non_shared_params = self.params
# filter out shared params for all but the owner FSDP module.
if len(full_tensors) < len(non_shared_params):
non_shared_params = list(
filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params)
)
non_shared_params = self.non_shared_params()
assert len(full_tensors) == len(
non_shared_params
), f"{len(full_tensors)} vs. {len(non_shared_params)}"
......@@ -1809,6 +1816,18 @@ class FullyShardedDataParallel(nn.Module):
self.has_full_params = False
if self._has_shared_params:
# self.has_full_params flag can be out of sync if a shared param is
# sharded by another FSDP instance. An example is the eval case where
# this instance has reshard_after_forward=False but the sharing instance
# has reshard_after_forward=True. Then, on the second forward, the
# other instance can shard the shared param, but this instance can
# mistakenly think the full param is already gathered based on the
# has_full_params flag.
#
# Therefore, we update the flag accordingly here.
self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)
# Early exit if we already have full params and don't need full precision.
if self.has_full_params and not force_full_precision:
for p in self.params:
......@@ -2148,7 +2167,14 @@ class FullyShardedDataParallel(nn.Module):
for k, v in sd_state.items():
gathered_state[k] = {}
singleton_state[k] = {}
desired_buffer_size = self._fsdp_instances[k].flat_param._full_param_padded.size() # type: ignore
# For shared params, we are not flattening. We have only 1 non-shared
# param that has the optimizer state. So we handle it with the correct
# parameter list.
non_shared_params = cast(FullyShardedDataParallel, self._fsdp_instances[k]).non_shared_params()
assert (
len(non_shared_params) == 1
), f"Only flatten param or a single non-shared param is supported: len={len(non_shared_params)}"
desired_buffer_size = non_shared_params[0]._full_param_padded.size()
buffer = None # for sharded tensors
singleton_buffer = None # for singleton tensors
for buffer_name, t in v.items():
......@@ -2214,7 +2240,7 @@ class FullyShardedDataParallel(nn.Module):
return new_state_dict
@property
def _fsdp_instances(self) -> List[nn.Module]:
def _fsdp_instances(self) -> List["FullyShardedDataParallel"]:
"""Returns all fsdp modules in self.modules() including self."""
return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
......
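The new non_shared_params() helper above relies on a tagging convention: parameters appended via append_shared_param carry an _is_shared attribute. A hedged, minimal sketch of that filter (written outside the FSDP class, illustrative only) is shown below; the same list is then used when sizing buffers in gather_full_optim_state_dict for non-flattened, shared-weight instances.

```python
# Hedged sketch of the _is_shared filtering convention (not the FSDP class
# itself): shared parameters are tagged, everything else is "non-shared".
from typing import List

import torch
import torch.nn as nn


def non_shared_params(params: List[nn.Parameter]) -> List[nn.Parameter]:
    # Parameters without the _is_shared attribute (or with it set to False)
    # are owned by this instance.
    return [p for p in params if not getattr(p, "_is_shared", False)]


owned = nn.Parameter(torch.randn(2, 2))
shared = nn.Parameter(torch.randn(2, 2))
shared._is_shared = True  # set when the param is appended as a shared param

result = non_shared_params([owned, shared])
assert len(result) == 1 and result[0] is owned
```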
......@@ -238,7 +238,9 @@ class FlattenParamsWrapper(nn.Module):
"""We used to support only a single flat_param. This allows us to
be backward compatible.
"""
assert len(self.flat_params) == 1, "Incorrect access to flat_param"
assert (
len(self.flat_params) == 1
), f"Incorrect access to flat_param: len(self.flat_params)={len(self.flat_params)}"
return self.flat_params[0]
def _init_flatten_params(
......
......@@ -27,4 +27,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchvision"]
known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
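For reference, a hedged illustration (using standard-library modules, not the project's real imports) of the isort behavior this config encodes: adding torchtext to known_third_party moves those imports into the third-party block alongside torch in the hunks above, and with force_sort_within_sections = true, plain import and from ... import statements are sorted together within a section instead of being grouped separately.

```python
# Example ordering produced with force_sort_within_sections = true:
# "import x" and "from x import y" are interleaved alphabetically by module.
import os
from os.path import join
import sys
from sys import argv
```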
......@@ -11,8 +11,6 @@ tests/utils/test_containers.py
tests/utils/test_parallel.py
tests/utils/test_state_dict.py
tests/utils/test_version.py
tests/nn/checkpoint/test_checkpoint_activations.py
tests/nn/checkpoint/test_checkpoint_activations_norm.py
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_param_bucket.py
tests/nn/wrap/test_wrap.py
......
tests/nn/checkpoint/test_checkpoint_activations.py
tests/nn/checkpoint/test_checkpoint_activations_norm.py
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_param_bucket.py
......
......@@ -50,7 +50,7 @@ class Model(nn.Module):
self.ln2 = nn.LayerNorm(D_MODEL).cuda().half()
if with_fsdp:
# Shared layers much be un-flatten.
# Shared layers must be un-flatten.
self.l0 = FSDP(self.l0, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
self.l1 = FSDP(self.l1, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
self.l1.append_shared_param(self.l0.module.weight)
......@@ -89,38 +89,46 @@ def temp_files():
@skip_if_single_gpu
@pytest.mark.parametrize("wrap_middle", ["none", "flat", "nonflat"])
def test_shared_weight_mevo(temp_files, wrap_middle):
@pytest.mark.parametrize("test_fn", ["train", "eval", "optim_state"])
def test_shared_weight_mevo(temp_files, wrap_middle, test_fn):
"""Test FSDP with a model with shared weights."""
if test_fn == "optim_state":
if wrap_middle != "flat":
pytest.skip("only support optim_state when root and middle part is flat")
world_size = 2
# Get ref.
model = Model()
sd_before = deepcopy(model.state_dict())
in_data = (torch.rand(BS, SEQ) * (VOCAB - 1)).cuda().long()
_train(model, in_data, world_size)
sd_after = deepcopy(model.state_dict())
# Before and after state should not be equal.
assert not objects_are_equal(sd_before, sd_after)
if test_fn == "train":
_train(model, in_data, world_size)
sd_after = deepcopy(model.state_dict())
# Before and after state should not be equal.
assert not objects_are_equal(sd_before, sd_after)
# Save data
torch.save(sd_before, temp_files[2])
torch.save(sd_after, temp_files[3])
if test_fn == "train":
torch.save(sd_after, temp_files[3])
torch.save(in_data, temp_files[4])
# Run FSDP
mp.spawn(
_dist_worker,
(world_size, temp_files, wrap_middle),
(world_size, temp_files, wrap_middle, test_fn),
nprocs=world_size,
)
def _dist_worker(rank, world_size, files, wrap_middle):
def _dist_worker(rank, world_size, files, wrap_middle, test_fn):
# Get data from files.
file1, file2, sd_before, sd_after, in_data = files
sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
if test_fn == "train":
sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))
result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
......@@ -130,19 +138,46 @@ def _dist_worker(rank, world_size, files, wrap_middle):
# To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner wrapping
# and make that work.
Model(with_fsdp=True, wrap_middle=wrap_middle),
flatten_parameters=False,
flatten_parameters=test_fn == "optim_state",
mixed_precision=False,
compute_dtype=torch.float16,
)
fsdp_model.load_state_dict(sd_before)
_train(fsdp_model, in_data)
objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
if test_fn == "train":
_train(fsdp_model, in_data)
objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
elif test_fn == "eval":
_eval(fsdp_model, in_data)
elif test_fn == "optim_state":
optim = SGD(fsdp_model.parameters(), lr=0.1)
for _ in range(3):
out = fsdp_model(in_data)
out.backward()
optim.step()
sd = fsdp_model.gather_full_optim_state_dict(optim)
if rank == 0:
# There should be 8 momentum buffers in the state.
assert len(sd["state"].keys()) == 8
else:
assert sd is None, "only rank 0 should have the optim state"
else:
assert 0, f"invalid test_fn {test_fn}"
teardown()
def _eval(model, in_data):
# run in eval mode
model.eval()
for _ in range(5):
out = model(in_data)
# also run under torch.no_grad()
for _ in range(5):
with torch.no_grad():
out = model(in_data)
def _train(model, in_data, steps_per_iter=1):
optim = SGD(model.parameters(), lr=0.1)
for _ in range(3):
......
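The optim_state branch of the test above asserts that rank 0's gathered state has 8 entries, a count specific to this model and its wrapping. Below is a hedged, single-process sketch of the underlying invariant with plain PyTorch (no FSDP; illustrative model; momentum enabled explicitly so SGD actually keeps per-parameter state): after a step, optim.state_dict()["state"] holds one entry per parameter that received a gradient.

```python
# Hedged, single-process sketch (plain PyTorch, no FSDP): with momentum enabled,
# SGD keeps one momentum_buffer per parameter that received a gradient, so the
# number of entries in the optimizer state matches the parameter count.
import torch
import torch.nn as nn
from torch.optim import SGD

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4), nn.Linear(4, 2))
optim = SGD(model.parameters(), lr=0.1, momentum=0.9)

loss = model(torch.randn(3, 4)).sum()
loss.backward()
optim.step()

num_params = len(list(model.parameters()))  # 6: two Linear weight/bias pairs + LayerNorm weight/bias
assert len(optim.state_dict()["state"]) == num_params
```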