Unverified commit 468874c8 authored by Shruti Bhosale, committed by GitHub

FSDP: Fix saving and loading checkpoints with use_sharded_state=True (#574)



* fix saving and loading checkpoints with use_sharded_state=True

* mypy fix

* better fix for the infinite recursion

- we need to explicitly call FSDP.state_dict from its local_state_dict
- added unit test that fails without the fix and works with the fix
- fixed mypy for the overloaded functions

* make CPU-only FSDP work, at least for state_dict
Co-authored-by: Min Xu <min.xu@acm.org>
Co-authored-by: Min Xu <min.xu.public@gmail.com>
Co-authored-by: Min Xu <m1n@fb.com>
parent bde4bac5
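For orientation, the sketch below (plain Python, no FSDP involved; Base and ShardedByDefault are invented names) shows the infinite recursion this commit fixes and how calling the base class's state_dict explicitly breaks the cycle. The real fix is the explicit FullyShardedDataParallel.state_dict(self, *args, **kwargs) call inside local_state_dict in the diff below, and the new unit test at the end of the diff reproduces exactly this subclassing pattern.

# A minimal sketch of the recursion: a subclass overrides state_dict() to
# return the sharded ("local") state, while the base class's
# local_state_dict() dispatches back through self.state_dict(), so the two
# methods would call each other forever.
class Base:
    def state_dict(self):
        return {"full": True}

    def local_state_dict(self):
        # Buggy: `return self.state_dict()` re-enters the subclass override
        # and recurses without bound (RecursionError).
        # Fixed: call the base-class implementation explicitly.
        return Base.state_dict(self)


class ShardedByDefault(Base):
    def state_dict(self):
        # Mirrors use_sharded_state=True: state_dict() delegates to the
        # sharded variant.
        return self.local_state_dict()


print(ShardedByDefault().state_dict())  # terminates instead of recursing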
@@ -11,7 +11,8 @@ import logging
from math import inf
import time
import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union
import typing
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Mapping, NamedTuple, Optional, Set, Tuple, Union
import torch
from torch.autograd import Variable
@@ -616,8 +617,18 @@ class FullyShardedDataParallel(nn.Module):
del self.orig_sizes
self._reset_lazy_init()
# TODO (Min): figuring out how to do typing for this overloaded function.
def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tensor]": # type: ignore
@typing.overload
def state_dict(
self, destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...
) -> Mapping[str, torch.Tensor]:
...
@typing.overload
def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> "OrderedDict[str, torch.Tensor]":
...
# Since we have overloads above, we can use Any here.
def state_dict(self, *args: Any, **kwargs: Any) -> Any:
"""
Returns the whole (unsharded) state of the module. Parameters are not
sharded, so the resulting state_dict can be loaded directly by the
@@ -627,6 +638,7 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
if torch.cuda.is_available():
torch.cuda.synchronize()
self._lazy_init()
if self.mixed_precision:
@@ -655,8 +667,18 @@ class FullyShardedDataParallel(nn.Module):
self._cast_buffers()
return state_dict
# TODO (Min): figuring out how to do typing for this overloaded function.
def local_state_dict(self, *args, **kwargs): # type: ignore
@typing.overload
def local_state_dict(
self, destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...
) -> Mapping[str, torch.Tensor]:
...
@typing.overload
def local_state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> "OrderedDict[str, torch.Tensor]":
...
# Since we have overloads above, we can use Any here.
def local_state_dict(self, *args: Any, **kwargs: Any) -> Any:
"""
Returns the local (sharded) state of the module. Parameters are sharded,
so the resulting state_dict can only be loaded after the Module has been
@@ -667,7 +689,9 @@ class FullyShardedDataParallel(nn.Module):
for module in self.modules(): # includes self
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module._no_return_full_state_dict())
return self.state_dict(*args, **kwargs)
# We need to call FSDP's own state_dict function explicitly, in case
# self.state_dict has been overridden by a child class of FSDP.
return FullyShardedDataParallel.state_dict(self, *args, **kwargs)
@contextlib.contextmanager
def _no_return_full_state_dict(self) -> Generator:
@@ -678,7 +702,7 @@ class FullyShardedDataParallel(nn.Module):
finally:
self._return_full_state_dict = backup
def load_state_dict(
def _load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
"""
@@ -695,6 +719,11 @@ class FullyShardedDataParallel(nn.Module):
self._lazy_init()
return self.module.load_state_dict(state_dict, strict)
def load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
return self._load_state_dict(state_dict, strict)
def load_local_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
@@ -704,7 +733,7 @@ class FullyShardedDataParallel(nn.Module):
for module in self.modules(): # includes self
if isinstance(module, FullyShardedDataParallel):
stack.enter_context(module._no_return_full_state_dict())
output = self.load_state_dict(state_dict, strict)
output = self._load_state_dict(state_dict, strict)
return output
@contextlib.contextmanager
@@ -961,12 +990,15 @@ class FullyShardedDataParallel(nn.Module):
"""Create streams to overlap data transfer and computation."""
if len(self._streams) > 0 or not self._is_root:
return
if torch.cuda.is_available():
# Stream to move main FP32 params (may be on CPU) to FP16 for forward.
self._streams["fp32_to_fp16"] = torch.cuda.Stream()
# Stream for all-gathering parameters.
self._streams["all_gather"] = torch.cuda.Stream()
# Stream for overlapping grad reduction with the backward pass.
self._streams["post_backward"] = torch.cuda.Stream()
# Helper for bucketing reduce-scatter ops. This is also shared with
# children instances to improve bucket utilization.
self._reducer = ReduceScatterBucketer(self.bucket_cap_mb)
@@ -984,6 +1016,8 @@ class FullyShardedDataParallel(nn.Module):
instance) needs to synchronize with the default stream to ensure the
previous optimizer step is done.
"""
if not torch.cuda.is_available():
return
if self.mixed_precision:
self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
else:
......
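The last two hunks above gate stream creation and stream waits on torch.cuda.is_available(), which is what lets state_dict run on a CPU-only host. A rough sketch of that guard pattern follows, with invented names; this is not FSDP's actual stream bookkeeping:

import torch

_streams: dict = {}

def setup_streams() -> None:
    # Only create CUDA streams when a GPU is present; on CPU-only hosts the
    # dict simply stays empty.
    if torch.cuda.is_available():
        _streams["all_gather"] = torch.cuda.Stream()
        _streams["post_backward"] = torch.cuda.Stream()

def wait_for_previous_optim_step() -> None:
    # Nothing to synchronize against on CPU, mirroring the early return added
    # in the hunk above.
    if not torch.cuda.is_available():
        return
    _streams["all_gather"].wait_stream(torch.cuda.current_stream())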
@@ -2,7 +2,8 @@
# Licensed under the MIT License.
from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
import typing
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, NamedTuple, Optional, Tuple, Union
import torch
from torch import Tensor
@@ -201,7 +202,18 @@ class FlattenParamsWrapper(nn.Module):
except AttributeError:
return getattr(self.module, name) # fallback to wrapped module
def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, Tensor]": # type: ignore
@typing.overload
def state_dict(
self, destination: Mapping[str, Tensor], prefix: str = ..., keep_vars: bool = ...
) -> Mapping[str, Tensor]:
...
@typing.overload
def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> "OrderedDict[str, Tensor]":
...
# Since we have overloads above, we can use Any here.
def state_dict(self, *args: Any, **kwargs: Any) -> Any:
"""Return the wrapped module's state_dict (unflattened)."""
if self.is_flattened and self._auto_unflatten_state_dict:
with self.unflatten_params(recurse=False):
......
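The state_dict overload blocks above (in FullyShardedDataParallel and FlattenParamsWrapper) follow the same typing pattern; the sketch below, with an invented Example class, shows its shape. The @typing.overload stubs exist only for type checkers, the single undecorated definition is the one that runs, and the repeated names are why flake8-annotations is added to the dev requirements later in this diff (plain flake8 would otherwise report them as F811 redefinitions).

import typing
from collections import OrderedDict
from typing import Any, Mapping

from torch import Tensor


class Example:
    @typing.overload
    def state_dict(
        self, destination: Mapping[str, Tensor], prefix: str = ..., keep_vars: bool = ...
    ) -> Mapping[str, Tensor]:
        ...

    @typing.overload
    def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> "OrderedDict[str, Tensor]":
        ...

    # Only this definition exists at runtime; the overloads above are just
    # signatures for mypy, so *args/**kwargs keeps the implementation simple.
    def state_dict(self, *args: Any, **kwargs: Any) -> Any:
        return OrderedDict()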
@@ -57,6 +57,8 @@ if TYPE_CHECKING:
else:
Base = nn.Module
skip_if_cuda = pytest.mark.skipif(torch.cuda.is_available(), reason="Testing only on CPUs to save time")
skip_if_no_cuda = pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.device_count() < 1, reason="CUDA required"
)
@@ -75,7 +77,7 @@ skip_if_py38 = pytest.mark.skipif(
skip_if_py39_no_cuda = pytest.mark.skipif(
not torch.cuda.is_available() and sys.version_info.major == 3 and sys.version_info.minor == 9,
reason="Python3.9 wo CUDA is skipped",
reason="Python3.9 without CUDA is skipped",
)
available_devices = ["cpu"]
......
@@ -2,8 +2,11 @@
-r requirements.txt
# Tools for static checking.
# - flake8-annotations is needed to avoid the F811 error with overloaded
#   function typing under mypy.
black == 19.10b0
flake8 == 3.7.9
flake8-annotations == 2.6.2
isort == 5.6.4
mypy == 0.790
......
@@ -71,10 +71,8 @@ class Module(Generic[T_co]):
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
T_destination = TypeVar('T_destination', bound=Mapping[str, Tensor])
@overload
def state_dict(self, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ...
def state_dict(self, destination: Mapping[str, Tensor], prefix: str = ..., keep_vars: bool = ...) -> Mapping[str, Tensor]: ...
@overload
def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> OrderedDict[str, Tensor]: ...
......
@@ -10,8 +10,8 @@ from parameterized import parameterized
import torch
from torch import nn
from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import objects_are_equal
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx
from .test_fsdp import (
CONFIG_OPTIONS,
@@ -111,10 +111,10 @@ class TestSaveLoadStateDict(DistributedTest):
autocast = ddp_model.mixed_precision
cls._train_for_several_steps(ddp_model, 2, autocast)
state_1 = ddp_model.state_dict()
# You must make a new FullyShardedDataParallel instance to use module.load_state_dict
# You must make a new FSDP instance to use module.load_state_dict
unwrapped_model = TransformerWithSharedParams(group)
unwrapped_model.load_state_dict(state_1)
new_ddp_model = FullyShardedDataParallel(unwrapped_model, group, **config).cuda()
new_ddp_model = FSDP(unwrapped_model, group, **config).cuda()
cls._train_for_several_steps(new_ddp_model, 2, autocast)
try:
ddp_model.load_state_dict(new_ddp_model.state_dict())
@@ -144,7 +144,7 @@ class TestSaveLoadStateDict(DistributedTest):
if config["mixed_precision"]:
config["compute_dtype"] = torch.float32
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
model = FSDP(model, group, **config).cuda()
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
# Round-trip state dict save/load/save.
@@ -162,7 +162,7 @@ class TestSaveLoadStateDict(DistributedTest):
def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
# Create a nested FSDP-wrapped instance.
model = NestedWrappedModule(group, config)
model = FullyShardedDataParallel(model, group, **config).cuda()
model = FSDP(model, group, **config).cuda()
cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
# Round trip state dict save/load/save.
@@ -216,7 +216,7 @@ class TestStateDictDeviceDtype(DistributedTest):
if pure_fp16:
assert not config["mixed_precision"]
model = model.half()
fsdp_model = FullyShardedDataParallel(model, group, **config)
fsdp_model = FSDP(model, group, **config)
if not config["cpu_offload"]:
fsdp_model = fsdp_model.cuda()
autocast = fsdp_model.mixed_precision or pure_fp16
@@ -244,5 +244,28 @@ class TestStateDictDeviceDtype(DistributedTest):
assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
@skip_if_cuda
def test_local_state_dict_calls_state_dict_recursion():
"""Testing the case of infinite recursive when FSDP is subclassed"""
class TestModule(FSDP):
def __init__(self):
super().__init__(module=nn.Linear(100, 100))
def state_dict(self, *args, **kwargs):
return self.local_state_dict(*args, **kwargs)
rank = 0
world_size = 1
with temp_files_ctx(2) as temp_files:
result = dist_init(rank, world_size, temp_files[0], temp_files[1])
assert result, "Dist init failed"
m = TestModule()
d = m.state_dict()
teardown()
if __name__ == "__main__":
unittest.main()
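Finally, for context on the use_sharded_state=True flow named in the title, here is a hedged sketch of the sharded (local) state-dict round trip from user code; the tiny Linear model, the checkpoint file names, and the process-group setup are illustrative only and not part of this commit:

import torch
import torch.distributed as dist
from torch import nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes a process group is already initialized (e.g. via dist_init as in the
# tests above); the Linear module is just a stand-in.
fsdp_model = FSDP(nn.Linear(8, 8))
rank = dist.get_rank()

# Save: each rank writes only its own shard of the parameters.
torch.save(fsdp_model.local_state_dict(), f"ckpt_shard_{rank}.pt")

# Load: a freshly wrapped model with the same world size reads its shard back.
new_model = FSDP(nn.Linear(8, 8))
new_model.load_local_state_dict(torch.load(f"ckpt_shard_{rank}.pt"))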