Unverified commit 0db50ce5 authored by Min Xu, committed by GitHub

[fix] [MEVO]: make mevo work with eval and optim_state checkpointing (#851)



* [fix]: fix eval for shared weight FSDP

* fixing optim state saving

* add changelog

* reformat with newer local isort

* update test

* avoid computing reference state unless we are testing training

* added optim_state test

* make mypy happy

* move tests; maybe we need to run CUDA-memory-related tests in the first of the lists
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent fd831c4a
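For context, here is a minimal sketch of the two code paths this commit fixes, modeled on the test added at the bottom of this diff. The helper name and data are illustrative; it assumes `torch.distributed` is already initialized on each rank and that the wrapped model returns a scalar loss (as the MEVO test model does).

```python
import torch
from torch.optim import SGD

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


def eval_then_gather_optim_state(model: torch.nn.Module, in_data: torch.Tensor):
    # Hypothetical helper; the FSDP kwargs mirror the test in this commit.
    fsdp_model = FSDP(model, mixed_precision=False)

    # Eval path: with this fix, MEVO no longer asserts on input.requires_grad
    # when gradients are disabled.
    fsdp_model.eval()
    with torch.no_grad():
        fsdp_model(in_data)

    # Optim-state path: take a step, then consolidate the sharded optimizer
    # state; gather_full_optim_state_dict returns None on non-zero ranks.
    fsdp_model.train()
    optim = SGD(fsdp_model.parameters(), lr=0.1)
    fsdp_model(in_data).backward()
    optim.step()
    return fsdp_model.gather_full_optim_state_dict(optim)
```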
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   and the file path for storing params on SSD. Note: This is an experimental feature. [#855]

 ### Changed
+- MEVO: fixed eval and checkpointing code paths [#851]
 - Cleanup: Moving forward we would be testing all of our code with Python 3.9.7, CUDA 11.2 and the following three versions of PyTorch [#847]:
   - the most recent stable version
   - the most recent LTS version
...
@@ -10,7 +10,6 @@ import tempfile
 import torch
 from torch.utils.data import DataLoader
-import torchtext
 from torchtext.data.utils import get_tokenizer
 from torchtext.utils import download_from_url, extract_archive
...
@@ -18,6 +18,8 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
+import torchtext
+from torchtext.data.utils import get_tokenizer

 from fairscale.experimental.nn.ampnet_pipe import pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
@@ -25,8 +27,6 @@ from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
 from fairscale.nn.pipe import LazyModule
 from fairscale.optim import GradScaler
 from fairscale.utils.testing import dist_init, get_worker_map
-import torchtext
-from torchtext.data.utils import get_tokenizer

 try:
     from fairscale.optim import Adam  # type: ignore
...
@@ -378,7 +378,7 @@ class BackwardTrigger(nn.Module):
     def __init__(self, linked_param: torch.Tensor):
         super().__init__()
         assert isinstance(linked_param, nn.Parameter)
-        self.trigger = nn.Parameter(torch.rand(1, dtype=linked_param.dtype))
+        self.trigger = nn.Parameter(torch.rand(1, dtype=linked_param.dtype, device=linked_param.device))
         self.trigger._linked_param = linked_param

     def forward(self) -> torch.Tensor:  # type: ignore
@@ -437,7 +437,8 @@ class MemoryEfficientVocabOutput(nn.Module):  # AKA. MEVO
             print("DEBUG cur, peak", cur_mem, mem)
         assert isinstance(input, torch.Tensor)
         assert isinstance(target, torch.Tensor)
-        assert input.requires_grad
+        if torch.is_grad_enabled():
+            assert input.requires_grad
         input, target = _reshape_inputs(input, target)
         tokens, d_model = input.shape
...
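As a small standalone illustration of why the new guard is needed (not part of the patch itself): activations produced under `torch.no_grad()` do not require grad, so the old unconditional assert would fire during evaluation.

```python
import torch

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

# Grad mode: the output of a layer with trainable params requires grad.
assert lin(x).requires_grad

# No-grad mode (e.g. eval): it does not, which is what used to trip
# MEVO's unconditional `assert input.requires_grad`.
with torch.no_grad():
    assert not lin(x).requires_grad
```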
@@ -4,12 +4,15 @@
 # LICENSE file in the root directory of this source tree.

 """These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
 import copy
-from typing import Any, Dict, Iterator, List, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, cast

 import torch

 from fairscale.nn.misc import FlattenParamsWrapper

+if TYPE_CHECKING:
+    from fairscale.nn.data_parallel import FullyShardedDataParallel
+
 # These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
 UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}
@@ -84,10 +87,11 @@ def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_

 def _unflatten_optim_state(
     combined_state: Dict[int, Dict],
-    instance_list: List[torch.nn.Module],
+    instance_list: List["FullyShardedDataParallel"],
     world_pad_info: List[List[List[int]]],
     singleton_state: Dict[int, Dict],
 ) -> Tuple[Dict[int, Dict], Dict[int, int]]:
+    """Convert optimizer state for flattened parameters back into the original, unflattened ones."""
     # local ids are the keys in the current state (combined_state), (usually fewer)
     # global ids will be the keys in the unflattened state
     next_global_id = 0  # gets incremented
@@ -100,7 +104,13 @@ def _unflatten_optim_state(
     # Local corresponds to flattened, global corresponds to unflattened.
     # Casting needed only for mypy.
-    num_global_params = [cast(int, m.num_params_managed) for m in instance_list]
+    num_global_params: List[int] = []
+    for m in instance_list:
+        if m.flatten_parameters:
+            num_flatten = cast(int, m.num_params_managed)
+            num_global_params.append(num_flatten)
+        else:
+            num_global_params.append(len(m.non_shared_params()))
     global_to_local_id = {}
     for local_id, num_unflat in enumerate(num_global_params):
         for _ in range(num_unflat):
@@ -129,18 +139,26 @@ def _unflatten_optim_state(
             assert isinstance(v, list), f"got {k}: {v} for {local_id}"
             v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
             flat_buffer = torch.cat(v_unpad)
-            # Casting needed only for mypy.
-            param_views: Iterator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views([flat_buffer])
-            for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
-                assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
-                unflat_state[global_id][k] = param_view
-            unflat_state[global_id].update(singleton_state[local_id])
+            if instance_list[local_id].flatten_parameters:
+                # Unflatten. Casting needed only for mypy.
+                param_views: Iterator = cast(FlattenParamsWrapper, instance_list[local_id]).get_param_views(
+                    [flat_buffer]
+                )
+                for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
+                    assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
+                    unflat_state[global_id][k] = param_view
+            else:
+                # Copy non-flattened state directly.
+                assert len(local_to_global[local_id]) == 1, "Only a single non-flattened parameter is supported"
+                global_id = local_to_global[local_id][0]
+                unflat_state[global_id][k] = flat_buffer
+            unflat_state[global_id].update(singleton_state[local_id])
     return unflat_state, global_to_local_id


 def build_unflat_state_dict(
-    instance_list: List[torch.nn.Module],
+    instance_list: List["FullyShardedDataParallel"],
     world_pad_info: List[List[List[int]]],
     state: Dict[int, Dict[str, List[torch.Tensor]]],
     singleton_state: Dict[int, Dict[str, List[torch.Tensor]]],
...
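The per-instance bookkeeping introduced above can be restated compactly as follows (a sketch only; `count_unflat_entries` is a hypothetical helper, and `instance_list` stands for the FSDP modules passed to these utilities):

```python
from typing import List


def count_unflat_entries(instance_list) -> List[int]:
    """How many unflattened optimizer-state entries each FSDP instance contributes."""
    counts: List[int] = []
    for m in instance_list:
        if m.flatten_parameters:
            # Flattened: one flat_param covering num_params_managed originals.
            counts.append(int(m.num_params_managed))
        else:
            # Non-flattened (e.g. shared-weight) instance: only its own
            # non-shared params carry optimizer state.
            counts.append(len(m.non_shared_params()))
    return counts
```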
@@ -353,6 +353,7 @@ class FullyShardedDataParallel(nn.Module):
                 params.append(param)

         self._has_params = len(params) > 0
+        self._has_shared_params = False

         # TODO(anj): Should we conditionally do this only if we have params?
         # TODO(anj): Figure out if we can allocate the buffer during sharding.
@@ -492,6 +493,14 @@
             len(list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))) > 0
         ), "Must have at least 1 non-shared param."
         self.params.append(p)
+        self._has_shared_params = True
+
+    def non_shared_params(self) -> List[nn.Parameter]:
+        """Return the list of non-shared parameters."""
+        if self._has_shared_params:
+            return list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))
+        else:
+            return self.params

     def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
         """
@@ -1050,9 +1059,7 @@ class FullyShardedDataParallel(nn.Module):
             non_shared_params = self.params
             # filter out shared params for all but the owner FSDP module.
            if len(full_tensors) < len(non_shared_params):
-                non_shared_params = list(
-                    filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params)
-                )
+                non_shared_params = self.non_shared_params()
             assert len(full_tensors) == len(
                 non_shared_params
             ), f"{len(full_tensors)} vs. {len(non_shared_params)}"
@@ -1809,6 +1816,18 @@ class FullyShardedDataParallel(nn.Module):

         self.has_full_params = False

+        if self._has_shared_params:
+            # The self.has_full_params flag can be out of sync if a shared param
+            # is sharded by another FSDP instance. For example, in eval with
+            # reshard_after_forward=False on this instance but
+            # reshard_after_forward=True on the sharing instance, the other
+            # instance can shard the shared param on the second forward, while
+            # this instance would mistakenly conclude from has_full_params that
+            # the full param is already gathered.
+            #
+            # Therefore, we update the flag accordingly here.
+            self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)
+
         # Early exit if we already have full params and don't need full precision.
         if self.has_full_params and not force_full_precision:
             for p in self.params:
@@ -2148,7 +2167,14 @@ class FullyShardedDataParallel(nn.Module):
         for k, v in sd_state.items():
             gathered_state[k] = {}
             singleton_state[k] = {}
-            desired_buffer_size = self._fsdp_instances[k].flat_param._full_param_padded.size()  # type: ignore
+            # For shared params we do not flatten, and only one non-shared param
+            # holds the optimizer state, so take the buffer size from the
+            # correct parameter list.
+            non_shared_params = cast(FullyShardedDataParallel, self._fsdp_instances[k]).non_shared_params()
+            assert (
+                len(non_shared_params) == 1
+            ), f"Only flatten param or a single non-shared param is supported: len={len(non_shared_params)}"
+            desired_buffer_size = non_shared_params[0]._full_param_padded.size()
             buffer = None  # for sharded tensors
             singleton_buffer = None  # for singleton tensors
             for buffer_name, t in v.items():
@@ -2214,7 +2240,7 @@ class FullyShardedDataParallel(nn.Module):
         return new_state_dict

     @property
-    def _fsdp_instances(self) -> List[nn.Module]:
+    def _fsdp_instances(self) -> List["FullyShardedDataParallel"]:
         """Returns all fsdp modules in self.modules() including self."""
         return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
...
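The filtering that `non_shared_params()` relies on reduces to checking the `_is_shared` marker on each parameter. A tiny self-contained sketch with plain parameters (no FSDP wrapping):

```python
import torch

owned = torch.nn.Parameter(torch.zeros(4))
borrowed = torch.nn.Parameter(torch.zeros(4))
borrowed._is_shared = True  # the attribute the FSDP filter checks for

params = [owned, borrowed]
non_shared = [p for p in params if not getattr(p, "_is_shared", False)]
assert len(non_shared) == 1 and non_shared[0] is owned
```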
@@ -238,7 +238,9 @@ class FlattenParamsWrapper(nn.Module):
         """We used to support only a single flat_param. This allows us to
         be backward compatible.
         """
-        assert len(self.flat_params) == 1, "Incorrect access to flat_param"
+        assert (
+            len(self.flat_params) == 1
+        ), f"Incorrect access to flat_param: len(self.flat_params)={len(self.flat_params)}"
         return self.flat_params[0]

     def _init_flatten_params(
...
@@ -27,4 +27,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchvision"]
+known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
@@ -11,8 +11,6 @@ tests/utils/test_containers.py
 tests/utils/test_parallel.py
 tests/utils/test_state_dict.py
 tests/utils/test_version.py
-tests/nn/checkpoint/test_checkpoint_activations.py
-tests/nn/checkpoint/test_checkpoint_activations_norm.py
 tests/nn/misc/test_grad_bucket.py
 tests/nn/misc/test_param_bucket.py
 tests/nn/wrap/test_wrap.py
...
+tests/nn/checkpoint/test_checkpoint_activations.py
+tests/nn/checkpoint/test_checkpoint_activations_norm.py
 tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
 tests/nn/misc/test_grad_bucket.py
 tests/nn/misc/test_param_bucket.py
...
@@ -50,7 +50,7 @@ class Model(nn.Module):
         self.ln2 = nn.LayerNorm(D_MODEL).cuda().half()

         if with_fsdp:
-            # Shared layers much be un-flatten.
+            # Shared layers must be un-flattened.
             self.l0 = FSDP(self.l0, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
             self.l1 = FSDP(self.l1, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
             self.l1.append_shared_param(self.l0.module.weight)
@@ -89,38 +89,46 @@ def temp_files():
 @skip_if_single_gpu
 @pytest.mark.parametrize("wrap_middle", ["none", "flat", "nonflat"])
-def test_shared_weight_mevo(temp_files, wrap_middle):
+@pytest.mark.parametrize("test_fn", ["train", "eval", "optim_state"])
+def test_shared_weight_mevo(temp_files, wrap_middle, test_fn):
     """Test FSDP with a model with shared weights."""
+    if test_fn == "optim_state":
+        if wrap_middle != "flat":
+            pytest.skip("optim_state is only supported when the root and middle parts are flat")
+
     world_size = 2

     # Get ref.
     model = Model()
     sd_before = deepcopy(model.state_dict())
     in_data = (torch.rand(BS, SEQ) * (VOCAB - 1)).cuda().long()
-    _train(model, in_data, world_size)
-    sd_after = deepcopy(model.state_dict())
-    # Before and after state should not be equal.
-    assert not objects_are_equal(sd_before, sd_after)
+    if test_fn == "train":
+        _train(model, in_data, world_size)
+        sd_after = deepcopy(model.state_dict())
+        # Before and after state should not be equal.
+        assert not objects_are_equal(sd_before, sd_after)

     # Save data
     torch.save(sd_before, temp_files[2])
-    torch.save(sd_after, temp_files[3])
+    if test_fn == "train":
+        torch.save(sd_after, temp_files[3])
     torch.save(in_data, temp_files[4])

     # Run FSDP
     mp.spawn(
         _dist_worker,
-        (world_size, temp_files, wrap_middle),
+        (world_size, temp_files, wrap_middle, test_fn),
         nprocs=world_size,
     )


-def _dist_worker(rank, world_size, files, wrap_middle):
+def _dist_worker(rank, world_size, files, wrap_middle, test_fn):
     # Get data from files.
     file1, file2, sd_before, sd_after, in_data = files
     sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
-    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
+    if test_fn == "train":
+        sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
     in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

     result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
@@ -130,19 +138,46 @@ def _dist_worker(rank, world_size, files, wrap_middle):
         # To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner wrapping
         # and make that work.
         Model(with_fsdp=True, wrap_middle=wrap_middle),
-        flatten_parameters=False,
+        flatten_parameters=test_fn == "optim_state",
         mixed_precision=False,
         compute_dtype=torch.float16,
     )
     fsdp_model.load_state_dict(sd_before)

-    _train(fsdp_model, in_data)
-    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
+    if test_fn == "train":
+        _train(fsdp_model, in_data)
+        objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
+    elif test_fn == "eval":
+        _eval(fsdp_model, in_data)
+    elif test_fn == "optim_state":
+        optim = SGD(fsdp_model.parameters(), lr=0.1)
+        for _ in range(3):
+            out = fsdp_model(in_data)
+            out.backward()
+            optim.step()
+        sd = fsdp_model.gather_full_optim_state_dict(optim)
+        if rank == 0:
+            # There should be 8 momentum buffers in the state.
+            assert len(sd["state"].keys()) == 8
+        else:
+            assert sd is None, "only rank 0 should have the optim state"
+    else:
+        assert 0, f"invalid test_fn {test_fn}"

     teardown()


+def _eval(model, in_data):
+    # Run in eval mode.
+    model.eval()
+    for _ in range(5):
+        out = model(in_data)
+    # Also run under torch.no_grad().
+    for _ in range(5):
+        with torch.no_grad():
+            out = model(in_data)
+
+
 def _train(model, in_data, steps_per_iter=1):
     optim = SGD(model.parameters(), lr=0.1)
     for _ in range(3):
...