Unverified commit 468874c8, authored by Shruti Bhosale and committed by GitHub

FSDP: Fix saving and loading checkpoints with use_sharded_state=True (#574)



* fix saving and loading checkpoints with use_sharded_state=True

* mypy fix

* better fix of the infinite recursion

- we need to specifically call FSDP.state_dict from its local state_dict
- added unit test that fails without the fix and works with the fix
- fixed mypy for the overloaded functions

* make cpu-only fsdp work for state_dict at least
Co-authored-by: Min Xu <min.xu@acm.org>
Co-authored-by: Min Xu <min.xu.public@gmail.com>
Co-authored-by: Min Xu <m1n@fb.com>
parent bde4bac5
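
The recursion mentioned in the commit message is easiest to see outside FSDP. A minimal, self-contained sketch (the class names Base and Child are illustrative only, not from the commit): when a parent method dispatches through `self`, a subclass that overrides `state_dict` in terms of `local_state_dict` sends the two methods into an infinite loop; calling the parent implementation explicitly breaks the cycle.

class Base:
    def state_dict(self):
        return {"full": 1}

    def local_state_dict(self):
        # Pre-fix pattern: `return self.state_dict()` dispatches through `self`,
        # so Child.state_dict below is re-entered and the recursion never ends.
        # Post-fix pattern: name the parent class explicitly.
        return Base.state_dict(self)


class Child(Base):
    def state_dict(self, *args, **kwargs):
        # Subclass redefines state_dict in terms of local_state_dict,
        # mirroring the use_sharded_state=True use case.
        return self.local_state_dict(*args, **kwargs)


print(Child().state_dict())  # {'full': 1}; the pre-fix pattern raises RecursionError
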
@@ -11,7 +11,8 @@ import logging
 from math import inf
 import time
 import traceback
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union
+import typing
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Mapping, NamedTuple, Optional, Set, Tuple, Union

 import torch
 from torch.autograd import Variable
@@ -616,8 +617,18 @@ class FullyShardedDataParallel(nn.Module):
         del self.orig_sizes
         self._reset_lazy_init()

-    # TODO (Min): figuring out how to do typing for this overloaded function.
-    def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tensor]":  # type: ignore
+    @typing.overload
+    def state_dict(
+        self, destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...
+    ) -> Mapping[str, torch.Tensor]:
+        ...
+
+    @typing.overload
+    def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> "OrderedDict[str, torch.Tensor]":
+        ...
+
+    # Since we have overloads above, we can use Any here.
+    def state_dict(self, *args: Any, **kwargs: Any) -> Any:
         """
         Returns the whole (unsharded) state of the module. Parameters are not
         sharded, so the resulting state_dict can be loaded directly by the
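
The `@typing.overload` declarations above exist only for the type checker; at runtime only the final `def state_dict(self, *args: Any, **kwargs: Any) -> Any` body is defined. A small, generic sketch of the pattern (the `Config.get` class is ours, purely illustrative):

import typing
from typing import Any, Dict, Optional


class Config:
    def __init__(self) -> None:
        self._data: Dict[str, str] = {"mode": "sharded"}

    @typing.overload
    def get(self, key: str) -> Optional[str]:
        ...

    @typing.overload
    def get(self, key: str, default: str) -> str:
        ...

    # Only this implementation exists at runtime; mypy checks callers
    # against the two declared signatures above, so Any is fine here.
    def get(self, *args: Any, **kwargs: Any) -> Any:
        key = args[0] if args else kwargs["key"]
        default = args[1] if len(args) > 1 else kwargs.get("default")
        return self._data.get(key, default)


print(Config().get("mode"))          # "sharded"
print(Config().get("missing", "x"))  # "x"
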
@@ -627,6 +638,7 @@ class FullyShardedDataParallel(nn.Module):
         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         self._lazy_init()
         if self.mixed_precision:
@@ -655,8 +667,18 @@ class FullyShardedDataParallel(nn.Module):
             self._cast_buffers()
         return state_dict

-    # TODO (Min): figuring out how to do typing for this overloaded function.
-    def local_state_dict(self, *args, **kwargs):  # type: ignore
+    @typing.overload
+    def local_state_dict(
+        self, destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...
+    ) -> Mapping[str, torch.Tensor]:
+        ...
+
+    @typing.overload
+    def local_state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> "OrderedDict[str, torch.Tensor]":
+        ...
+
+    # Since we have overloads above, we can use Any here.
+    def local_state_dict(self, *args: Any, **kwargs: Any) -> Any:
         """
         Returns the local (sharded) state of the module. Parameters are sharded,
         so the resulting state_dict can only be loaded after the Module has been
@@ -667,7 +689,9 @@ class FullyShardedDataParallel(nn.Module):
             for module in self.modules():  # includes self
                 if isinstance(module, FullyShardedDataParallel):
                     stack.enter_context(module._no_return_full_state_dict())
-            return self.state_dict(*args, **kwargs)
+            # We need to specially call FSDP's state_dict function in case
+            # self.state_dict is a function from a child class of FSDP.
+            return FullyShardedDataParallel.state_dict(self, *args, **kwargs)

     @contextlib.contextmanager
     def _no_return_full_state_dict(self) -> Generator:
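
Calling `FullyShardedDataParallel.state_dict(self, ...)` instead of `self.state_dict(...)` bypasses dynamic dispatch, so a subclass override can no longer be re-entered from `local_state_dict`. A tiny illustration of the difference between the two call forms (names are ours, not from the commit):

class Parent:
    def who(self):
        return "Parent"

    def via_self(self):
        return self.who()        # dynamic dispatch: honors subclass overrides

    def via_class(self):
        return Parent.who(self)  # static: always runs Parent's implementation


class Sub(Parent):
    def who(self):
        return "Sub"


s = Sub()
print(s.via_self())   # "Sub"
print(s.via_class())  # "Parent"
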
@@ -678,7 +702,7 @@ class FullyShardedDataParallel(nn.Module):
         finally:
             self._return_full_state_dict = backup

-    def load_state_dict(
+    def _load_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
     ) -> NamedTuple:
         """
@@ -695,6 +719,11 @@ class FullyShardedDataParallel(nn.Module):
         self._lazy_init()
         return self.module.load_state_dict(state_dict, strict)

+    def load_state_dict(
+        self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
+    ) -> NamedTuple:
+        return self._load_state_dict(state_dict, strict)
+
     def load_local_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
     ) -> NamedTuple:
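
The load path gets the same treatment as the save path: the real work moves into a private `_load_state_dict`, the public `load_state_dict` becomes a thin forwarder, and internal callers such as `load_local_state_dict` (next hunk) use the private name so a subclass override of `load_state_dict` cannot be re-entered. A generic sketch of the indirection (class and method names are ours):

from typing import Dict


class Loader:
    def _load(self, state: Dict[str, int]) -> str:
        # Real implementation lives under a private name.
        return f"loaded {len(state)} tensors"

    def load(self, state: Dict[str, int]) -> str:
        # Public API is a thin wrapper that subclasses may override.
        return self._load(state)

    def load_local(self, state: Dict[str, int]) -> str:
        # Internal paths call the private method directly, so an
        # overridden load() is never re-entered from here.
        return self._load(state)


print(Loader().load_local({"w": 1}))  # "loaded 1 tensors"
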
@@ -704,7 +733,7 @@ class FullyShardedDataParallel(nn.Module):
             for module in self.modules():  # includes self
                 if isinstance(module, FullyShardedDataParallel):
                     stack.enter_context(module._no_return_full_state_dict())
-            output = self.load_state_dict(state_dict, strict)
+            output = self._load_state_dict(state_dict, strict)
         return output

     @contextlib.contextmanager
@@ -961,12 +990,15 @@ class FullyShardedDataParallel(nn.Module):
         """Create streams to overlap data transfer and computation."""
         if len(self._streams) > 0 or not self._is_root:
             return
-        # Stream to move main FP32 params (may be on CPU) to FP16 for forward.
-        self._streams["fp32_to_fp16"] = torch.cuda.Stream()
-        # Stream for all-gathering parameters.
-        self._streams["all_gather"] = torch.cuda.Stream()
-        # Stream for overlapping grad reduction with the backward pass.
-        self._streams["post_backward"] = torch.cuda.Stream()
+        if torch.cuda.is_available():
+            # Stream to move main FP32 params (may be on CPU) to FP16 for forward.
+            self._streams["fp32_to_fp16"] = torch.cuda.Stream()
+            # Stream for all-gathering parameters.
+            self._streams["all_gather"] = torch.cuda.Stream()
+            # Stream for overlapping grad reduction with the backward pass.
+            self._streams["post_backward"] = torch.cuda.Stream()
         # Helper for bucketing reduce-scatter ops. This is also shared with
         # children instances to improve bucket utilization.
         self._reducer = ReduceScatterBucketer(self.bucket_cap_mb)
@@ -984,6 +1016,8 @@ class FullyShardedDataParallel(nn.Module):
         instance) needs to synchronize with the default stream to ensure the
         previous optimizer step is done.
         """
+        if not torch.cuda.is_available():
+            return
         if self.mixed_precision:
             self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
         else:
...
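
Both CUDA-related hunks follow the same guard pattern: create streams only when `torch.cuda.is_available()` and return early from stream synchronization on CPU-only hosts. A standalone sketch of that pattern (not FSDP code, requires only PyTorch):

import torch

_streams = {}

def setup_streams() -> None:
    # Streams exist only on CUDA devices; skip them entirely on CPU-only hosts.
    if torch.cuda.is_available():
        _streams["all_gather"] = torch.cuda.Stream()

def wait_for_previous_work() -> None:
    if not torch.cuda.is_available():
        return  # nothing to synchronize without a GPU
    _streams["all_gather"].wait_stream(torch.cuda.current_stream())

setup_streams()
wait_for_previous_work()  # no-op on a CPU-only machine
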
@@ -2,7 +2,8 @@
 # Licensed under the MIT License.

 from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
+import typing
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, NamedTuple, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -201,7 +202,18 @@ class FlattenParamsWrapper(nn.Module):
         except AttributeError:
             return getattr(self.module, name)  # fallback to wrapped module

-    def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, Tensor]":  # type: ignore
+    @typing.overload
+    def state_dict(
+        self, destination: Mapping[str, Tensor], prefix: str = ..., keep_vars: bool = ...
+    ) -> Mapping[str, Tensor]:
+        ...
+
+    @typing.overload
+    def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> "OrderedDict[str, Tensor]":
+        ...
+
+    # Since we have overloads above, we can use Any here.
+    def state_dict(self, *args: Any, **kwargs: Any) -> Any:
         """Return the wrapped module's state_dict (unflattened)."""
         if self.is_flattened and self._auto_unflatten_state_dict:
             with self.unflatten_params(recurse=False):
...
@@ -57,6 +57,8 @@ if TYPE_CHECKING:
 else:
     Base = nn.Module

+skip_if_cuda = pytest.mark.skipif(torch.cuda.is_available(), reason="Testing only on CPUs to save time")
+
 skip_if_no_cuda = pytest.mark.skipif(
     not torch.cuda.is_available() or torch.cuda.device_count() < 1, reason="CUDA required"
 )
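
The new `skip_if_cuda` marker is the mirror image of `skip_if_no_cuda`: it skips a test when a GPU is present, so CPU-only code paths get exercised on GPU-less CI. Applying it is just a decorator, e.g. (the test name below is illustrative):

import pytest
import torch

skip_if_cuda = pytest.mark.skipif(torch.cuda.is_available(), reason="Testing only on CPUs to save time")

@skip_if_cuda
def test_cpu_only_code_path():
    # This test only runs on machines without CUDA.
    assert not torch.cuda.is_available()
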
@@ -75,7 +77,7 @@ skip_if_py38 = pytest.mark.skipif(
 skip_if_py39_no_cuda = pytest.mark.skipif(
     not torch.cuda.is_available() and sys.version_info.major == 3 and sys.version_info.minor == 9,
-    reason="Python3.9 wo CUDA is skipped",
+    reason="Python3.9 without CUDA is skipped",
 )

 available_devices = ["cpu"]
...
@@ -2,8 +2,11 @@
 -r requirements.txt

 # Tools for static checking.
+# - flake8-annotations is needed to avoid F811 error with overload
+#   function typing with mypy.
 black == 19.10b0
 flake8 == 3.7.9
+flake8-annotations == 2.6.2
 isort == 5.6.4
 mypy == 0.790
...
@@ -71,10 +71,8 @@ class Module(Generic[T_co]):
     # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
     # back that same object. But if they pass nothing, an `OrederedDict` is created and returned.
-    T_destination = TypeVar('T_destination', bound=Mapping[str, Tensor])
-
     @overload
-    def state_dict(self, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ...
+    def state_dict(self, destination: Mapping[str, Tensor], prefix: str = ..., keep_vars: bool = ...) -> Mapping[str, Tensor]: ...
     @overload
     def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> OrderedDict[str, Tensor]: ...
...
@@ -10,8 +10,8 @@ from parameterized import parameterized
 import torch
 from torch import nn

-from fairscale.nn.data_parallel import FullyShardedDataParallel
-from fairscale.utils.testing import objects_are_equal
+from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx

 from .test_fsdp import (
     CONFIG_OPTIONS,
@@ -111,10 +111,10 @@ class TestSaveLoadStateDict(DistributedTest):
         autocast = ddp_model.mixed_precision
         cls._train_for_several_steps(ddp_model, 2, autocast)
         state_1 = ddp_model.state_dict()
-        # You must make a new FullyShardedDataParallel instance to use module.load_state_dict
+        # You must make a new FSDP instance to use module.load_state_dict
         unwrapped_model = TransformerWithSharedParams(group)
         unwrapped_model.load_state_dict(state_1)
-        new_ddp_model = FullyShardedDataParallel(unwrapped_model, group, **config).cuda()
+        new_ddp_model = FSDP(unwrapped_model, group, **config).cuda()
         cls._train_for_several_steps(new_ddp_model, 2, autocast)
         try:
             ddp_model.load_state_dict(new_ddp_model.state_dict())
@@ -144,7 +144,7 @@ class TestSaveLoadStateDict(DistributedTest):
         if config["mixed_precision"]:
             config["compute_dtype"] = torch.float32
         model = NestedWrappedModule(group, config)
-        model = FullyShardedDataParallel(model, group, **config).cuda()
+        model = FSDP(model, group, **config).cuda()
         cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

         # Round-trip state dict save/load/save.
@@ -162,7 +162,7 @@ class TestSaveLoadStateDict(DistributedTest):
     def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
         # Create a nested FSDP-wrapped instance.
         model = NestedWrappedModule(group, config)
-        model = FullyShardedDataParallel(model, group, **config).cuda()
+        model = FSDP(model, group, **config).cuda()
         cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

         # Round trip state dict save/load/save.
@@ -216,7 +216,7 @@ class TestStateDictDeviceDtype(DistributedTest):
         if pure_fp16:
             assert not config["mixed_precision"]
             model = model.half()
-        fsdp_model = FullyShardedDataParallel(model, group, **config)
+        fsdp_model = FSDP(model, group, **config)
         if not config["cpu_offload"]:
             fsdp_model = fsdp_model.cuda()
         autocast = fsdp_model.mixed_precision or pure_fp16
@@ -244,5 +244,28 @@ class TestStateDictDeviceDtype(DistributedTest):
                 assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"


+@skip_if_cuda
+def test_local_state_dict_calls_state_dict_recursion():
+    """Testing the case of infinite recursion when FSDP is subclassed"""
+
+    class TestModule(FSDP):
+        def __init__(self):
+            super().__init__(module=nn.Linear(100, 100))
+
+        def state_dict(self, *args, **kwargs):
+            return self.local_state_dict(*args, **kwargs)
+
+    rank = 0
+    world_size = 1
+    with temp_files_ctx(2) as temp_files:
+        result = dist_init(rank, world_size, temp_files[0], temp_files[1])
+        assert result, "Dist init failed"
+
+        m = TestModule()
+        d = m.state_dict()
+
+        teardown()
+
+
 if __name__ == "__main__":
     unittest.main()