Unverified Commit 4b126c7b authored by Min Xu, committed by GitHub

[feat] support a context for loading state_dict for FSDP (#1065)



* [fix]: add a context for supporting state_dict from a non-FSDP parent module

* formatting

Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 3cc7fa8d
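For reference, here is a minimal single-process sketch of how the new context manager is meant to be used when loading a state_dict that came from a non-FSDP parent module. The tiny model, the single-rank gloo process group, and the import of no_pre_load_state_dict_hook from fairscale.nn.data_parallel are illustrative assumptions based on the hunks below, not part of this commit:

import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import no_pre_load_state_dict_hook

# FSDP needs an initialized process group; a single-process gloo group is enough here.
if not dist.is_initialized():
    dist.init_process_group("gloo", init_method="tcp://localhost:29500", world_size=1, rank=0)

# A plain (non-FSDP) parent module whose children are FSDP instances.
model = nn.Sequential(FSDP(nn.Linear(4, 4)), FSDP(nn.Linear(4, 4)))

# This state_dict comes from the non-FSDP parent, not from a root FSDP instance.
state_dict = model.state_dict()

# The context manager added in this commit turns off the pre-load key rewriting
# in FSDP and FlattenParamsWrapper while such a state_dict is loaded back.
with no_pre_load_state_dict_hook():
    model.load_state_dict(state_dict)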
@@ -5,7 +5,13 @@
 from typing import List

-from .fully_sharded_data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState, auto_wrap_bn
+from .fully_sharded_data_parallel import (
+    FullyShardedDataParallel,
+    OffloadConfig,
+    TrainingState,
+    auto_wrap_bn,
+    no_pre_load_state_dict_hook,
+)
 from .sharded_ddp import ShardedDataParallel

 __all__: List[str] = []
@@ -51,7 +51,7 @@ from fairscale.internal.parallel import (
 from fairscale.internal.params import calc_grad_norm, recursive_copy_to_device
 from fairscale.internal.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.internal.state_dict import replace_by_prefix_
-from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.misc import FlattenParamsWrapper, _enable_pre_load_state_dict_hook
 from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap

 from . import fsdp_optim_utils as ou
@@ -762,6 +762,7 @@ class FullyShardedDataParallel(nn.Module):
                 self.numel_padded_per_param.append(0)
                 continue
             p._is_sharded = True
+            # TODO (Min): broadcast from rank 0 to avoid each rank needing to init with the same seed?

             # Replace p.data with the relevant shard.
             orig_data = p.data
@@ -2581,10 +2582,25 @@ def _post_state_dict_hook(
     return state_dict


+@contextlib.contextmanager
+def no_pre_load_state_dict_hook() -> Generator:
+    """Disable the pre-load hook.
+
+    This is needed if we are loading a state_dict that was not produced by
+    a root FSDP instance.
+    """
+    global _enable_pre_load_state_dict_hook
+    bak = _enable_pre_load_state_dict_hook
+    _enable_pre_load_state_dict_hook = False
+    yield
+    _enable_pre_load_state_dict_hook = bak
+
+
 def _pre_load_state_dict_hook(
     state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
 ) -> None:
-    replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
+    if _enable_pre_load_state_dict_hook:
+        replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")


 def _clean_path(path: str) -> str:
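To make the gated rewrite above concrete, here is a toy illustration. The replace_by_prefix helper below is a stand-in for fairscale.internal.state_dict.replace_by_prefix_ (assumed to rename matching keys in place), and the key names are made up for the example:

from typing import Dict

import torch


def replace_by_prefix(state_dict: Dict[str, torch.Tensor], old_prefix: str, new_prefix: str) -> None:
    # Stand-in for fairscale.internal.state_dict.replace_by_prefix_: rename matching keys in place.
    for key in list(state_dict):
        if key.startswith(old_prefix):
            state_dict[new_prefix + key[len(old_prefix):]] = state_dict.pop(key)


# Keys as saved by a root FSDP instance have the wrapper prefix stripped, so the
# pre-load hook normally adds "_fsdp_wrapped_module." back before dispatching keys.
sd = {"layer.weight": torch.zeros(1)}
replace_by_prefix(sd, "", "_fsdp_wrapped_module.")
print(list(sd))  # ['_fsdp_wrapped_module.layer.weight']

# Inside no_pre_load_state_dict_hook() this rewrite is skipped, which the docstring
# above says is needed when the state_dict was not produced by a root FSDP instance.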
@@ -9,7 +9,7 @@ from typing import List
 # in favor of fairscale.nn.checkpoint.checkpoint_wrapper.
 from fairscale.nn.checkpoint import checkpoint_wrapper

-from .flatten_params_wrapper import FlattenParamsWrapper
+from .flatten_params_wrapper import FlattenParamsWrapper, _enable_pre_load_state_dict_hook
 from .param_bucket import GradBucket, ParamBucket

 __all__: List[str] = []
@@ -49,6 +49,9 @@ from fairscale.internal.state_dict import replace_by_prefix_
 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401

+# See no_pre_load_state_dict_hook context manager function in FSDP for more details.
+_enable_pre_load_state_dict_hook = True
+

 class FlatParameter(nn.Parameter):
     """A parameter that is initialized from a list of parameters and can be
@@ -543,6 +546,8 @@ def _post_state_dict_hook(
 def _pre_load_state_dict_hook(
     state_dict: Union[Dict[str, Tensor], "OrderedDict[str, Tensor]"], prefix: str, *args: Any
 ) -> None:
+    if not _enable_pre_load_state_dict_hook:
+        return
     # Push everything down to the ._fpw_module level.
     replace_by_prefix_(state_dict, prefix, prefix + "_fpw_module.")
     # The flat_param_* keys actually need to move one level up.
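And a matching toy sketch of the FlattenParamsWrapper side of the hook. The fpw_pre_load function and the key names are illustrative only, and the relocation of the flat_param_* keys mentioned in the comment above is not shown:

from typing import Dict

import torch


def fpw_pre_load(state_dict: Dict[str, torch.Tensor], prefix: str, enabled: bool = True) -> None:
    # Toy version of the gated hook: when enabled, push keys under `prefix` down one
    # level to the inner "_fpw_module."; when disabled, leave the state_dict untouched.
    if not enabled:
        return
    for key in list(state_dict):
        if key.startswith(prefix):
            state_dict[prefix + "_fpw_module." + key[len(prefix):]] = state_dict.pop(key)


sd = {"_fsdp_wrapped_module.layer.weight": torch.zeros(1)}
fpw_pre_load(sd, "_fsdp_wrapped_module.")
print(list(sd))  # ['_fsdp_wrapped_module._fpw_module.layer.weight']

fpw_pre_load(sd, "_fsdp_wrapped_module.", enabled=False)  # no-op, mirroring the early return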