Unverified commit b6dc98cf authored by Myle Ott, committed by GitHub

[fix] fix FSDP state_dict/load_state_dict for nested wrapped instances (#440)

parent 93d115c6
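
For reference, the round trip this fix targets, as a minimal sketch; the layer sizes, the import path, and the already-initialized process group are assumptions, not part of the commit:

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP  # import path assumed

    # Nested wrapping: an inner FSDP instance inside an outer one (sizes are illustrative).
    model = FSDP(nn.Sequential(nn.Linear(8, 4), FSDP(nn.Linear(4, 4)), nn.Linear(4, 8)))

    # state_dict() returns unsharded tensors under the unwrapped module's key names,
    # so the same dict can be loaded back into a nested FSDP-wrapped copy on every rank.
    full_sd = {k: v.clone() for k, v in model.state_dict().items()}
    model.load_state_dict(full_sd)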

@@ -29,6 +29,7 @@ from fairscale.utils.containers import (
 )
 from fairscale.utils.parallel import chunk_and_pad, validate_process_group
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
+from fairscale.utils.state_dict import replace_by_prefix_

 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401

@@ -172,11 +173,11 @@ class FullyShardedDataParallel(nn.Module):
         params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))
         if self.flatten_parameters and len(params) > 0:
-            self.module: nn.Module = FlattenParamsWrapper(module, param_list=params)
+            self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
             del module  # free original module in case it helps garbage collection
-            self.params = [self.module.flat_param]
+            self.params = [self._fsdp_wrapped_module.flat_param]
         else:
-            self.module = module
+            self._fsdp_wrapped_module = module
             self.params = params

         # Shard module parameters in place

@@ -192,8 +193,23 @@ class FullyShardedDataParallel(nn.Module):
         # pass. This will be False when inside the no_sync context manager.
         self.require_backward_grad_sync: bool = True
+        # Enum to indicate if we're in the forward/backward pass, idle, etc.
         self.training_state = TrainingState.IDLE
+
+        # Register hook after state_dict() to remove the "_fsdp_wrapped_module."
+        # prefix and before load_state_dict() to add it back.
+        self._register_state_dict_hook(_post_state_dict_hook)
+        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
+
+        # Flag to indicate whether state_dict() should automatically summon the
+        # full params. This defaults to True, but may be set to False if the
+        # user explicitly requests the local state dict via local_state_dict().
+        self._return_full_state_dict = True
+
+    @property
+    def module(self) -> nn.Module:
+        return self._fsdp_wrapped_module  # note: may be a FlattenParamsWrapper instance
+
     @torch.no_grad()
     def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         """Move all buffers to the specified device and dtype, recursively."""

@@ -235,6 +251,7 @@ class FullyShardedDataParallel(nn.Module):
         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
+        self._lazy_init()
         assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
         assert self.training_state == TrainingState.IDLE

@@ -374,7 +391,7 @@ class FullyShardedDataParallel(nn.Module):
         self._reset_lazy_init()

     # TODO (Min): figuring out how to do typing for this overloaded function.
-    def state_dict(self, *args, **kwargs):  # type: ignore
+    def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tensor]":  # type: ignore
         """
         Returns the whole (unsharded) state of the module. Parameters are not
         sharded, so the resulting state_dict can be loaded directly by the

@@ -384,16 +401,28 @@ class FullyShardedDataParallel(nn.Module):
         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
-        with self.summon_full_params():
+        if self.mixed_precision:
             # Buffers dtype stays consistent with parameters.
             self._all_buffers_to(dtype=torch.float32)
-            state_dict = self.module.state_dict(*args, **kwargs)
-            # We copy the state_dict since full param will be freed after
-            # we exit the summon_full_params() context.
-            for key in state_dict.keys():
-                state_dict[key] = state_dict[key].clone()
-        # In case we are in mixed precision, restore buffers back to fp16.
-        self._all_buffers_to(dtype=self.compute_dtype)
+
+        if self._return_full_state_dict:
+            if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
+                with self.summon_full_params():
+                    state_dict = super().state_dict(*args, **kwargs)
+            else:
+                torch.cuda.synchronize()
+                self._lazy_init()
+                state_dict = super().state_dict(*args, **kwargs)
+        else:
+            torch.cuda.synchronize()
+            self._lazy_init()
+            if self.flatten_parameters:
+                assert isinstance(self.module, FlattenParamsWrapper)
+                state_dict = self.module.flat_state_dict(*args, **kwargs)
+            else:
+                state_dict = super().state_dict(*args, **kwargs)
+
+        if self.mixed_precision:
+            # In case we are in mixed precision, restore buffers back to fp16.
+            self._all_buffers_to(dtype=self.compute_dtype)
+
         return state_dict
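
The branch on self._return_full_state_dict above is what local_state_dict() toggles further down; roughly, for a single flattened instance (hypothetical layer, import path and initialized process group assumed):

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP  # import path assumed

    fsdp = FSDP(nn.Linear(4, 4), flatten_parameters=True)
    full_sd = fsdp.state_dict()         # unsharded tensors, keyed by the original parameter names
    local_sd = fsdp.local_state_dict()  # only this rank's portion of the flattened parameter
    print(sum(t.numel() for t in full_sd.values()), sum(t.numel() for t in local_sd.values()))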

@@ -405,12 +434,21 @@ class FullyShardedDataParallel(nn.Module):
         so the resulting state_dict can only be loaded after the Module has been
         wrapped with FullyShardedDataParallel.
         """
-        torch.cuda.synchronize()
-        self._lazy_init()
-        if self.flatten_parameters:
-            return self.module.flat_state_dict(*args, **kwargs)  # type: ignore
-        else:
-            return self.module.state_dict(*args, **kwargs)
+        with contextlib.ExitStack() as stack:
+            # Tell any nested FSDP instances not to auto summon full params.
+            for module in self.modules():  # includes self
+                if isinstance(module, FullyShardedDataParallel):
+                    stack.enter_context(module._no_return_full_state_dict())
+            return self.state_dict(*args, **kwargs)
+
+    @contextlib.contextmanager
+    def _no_return_full_state_dict(self) -> Generator:
+        backup = self._return_full_state_dict
+        self._return_full_state_dict = False
+        try:
+            yield
+        finally:
+            self._return_full_state_dict = backup

     def load_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True

@@ -421,16 +459,25 @@ class FullyShardedDataParallel(nn.Module):
         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
-        with self.summon_full_params():
-            output = self.module.load_state_dict(state_dict, strict)
-        return output
+        if self._return_full_state_dict:
+            with self.summon_full_params():
+                return self.module.load_state_dict(state_dict, strict)
+        else:
+            torch.cuda.synchronize()
+            self._lazy_init()
+            return self.module.load_state_dict(state_dict, strict)

     def load_local_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
     ) -> NamedTuple:
         """Load a local (sharded) state_dict."""
-        torch.cuda.synchronize()
-        return self.module.load_state_dict(state_dict, strict)
+        with contextlib.ExitStack() as stack:
+            # Tell any nested FSDP instances not to auto summon full params.
+            for module in self.modules():  # includes self
+                if isinstance(module, FullyShardedDataParallel):
+                    stack.enter_context(module._no_return_full_state_dict())
+            output = self.load_state_dict(state_dict, strict)
+        return output

     @contextlib.contextmanager
     def no_sync(self) -> Generator:
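
A hedged sketch of how the local path is typically used for per-rank checkpointing; fsdp_model, the file names, and the same-world-size requirement on reload are assumptions here:

    import torch
    import torch.distributed as dist

    # Each rank saves only its own shard via the local (sharded) state dict.
    rank = dist.get_rank()
    torch.save(fsdp_model.local_state_dict(), f"checkpoint_rank{rank}.pt")

    # Restoring is expected to happen with the same world size and the same FSDP wrapping.
    fsdp_model.load_local_state_dict(torch.load(f"checkpoint_rank{rank}.pt"))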

@@ -457,30 +504,44 @@ class FullyShardedDataParallel(nn.Module):
             m.require_backward_grad_sync = old_flag

     @contextlib.contextmanager
-    def summon_full_params(self) -> Generator:
+    def summon_full_params(self, recurse: bool = True) -> Generator:
         """
-        A context manager to expose full params for the underlying model.
-        Can be useful *after* forward/backward for a model to get the params
-        for additional processing or checking.
+        A context manager to expose full params for the current FSDP instance.
+        Can be useful *after* forward/backward for a model to get the params for
+        additional processing or checking.
+
+        By default this will recursively summon all params for nested FSDP
+        instances; this can be disabled by setting ``recurse=False``.

-        This can be used on inner FSDPs.
-        This can *not* be used within a forward or backward pass. Nor can forward
-        and backward be started from within this context.
+        .. note:: This can be used on inner FSDPs.
+
+        .. note:: This can *not* be used within a forward or backward pass. Nor
+            can forward and backward be started from within this context.
         """
-        torch.cuda.synchronize()
-        self._lazy_init()
-        self.assert_state(TrainingState.IDLE)
-        # Set the state so that we assert when trying to go into
-        # forward/backward.
-        self.training_state = TrainingState.SUMMON_FULL_PARAMS
-        self._rebuild_full_params()
-        try:
-            yield
-        finally:
-            self._free_full_params()
-            self._use_fp32_param_shard()
-            self.training_state = TrainingState.IDLE
+        if recurse:
+            with contextlib.ExitStack() as stack:
+                # summon all params for any nested FlattenParamsWrapper instances
+                for module in self.modules():
+                    if isinstance(module, FullyShardedDataParallel):
+                        stack.enter_context(module.summon_full_params(recurse=False))
+                # yield to the caller, with full params in all nested instances
+                yield
+            # exiting from the ExitStack will re-shard params
+            return
+        else:
+            torch.cuda.synchronize()
+            self._lazy_init()
+            self.assert_state(TrainingState.IDLE)
+            # Set the state so that we assert when trying to go into
+            # forward/backward.
+            self.training_state = TrainingState.SUMMON_FULL_PARAMS
+            self._rebuild_full_params()
+            try:
+                yield
+            finally:
+                self._free_full_params()
+                self._use_fp32_param_shard()
+                self.training_state = TrainingState.IDLE

     def _reset_lazy_init(self) -> None:
         """Reset instance so :func:`_lazy_init` will run on the next forward."""

@@ -1011,3 +1072,23 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
         return
     assert data.storage().size() == 0
     data.storage().resize_(size.numel())
+
+
+def _post_state_dict_hook(
+    module: nn.Module, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
+) -> "OrderedDict[str, torch.Tensor]":
+    if module.training_state == TrainingState.SUMMON_FULL_PARAMS:
+        # We copy the state_dict since full param will be freed after
+        # we exit the summon_full_params() context.
+        for key in state_dict.keys():
+            state_dict[key] = state_dict[key].clone()
+
+    # Remove "_fsdp_wrapped_module." prefix
+    replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
+    return state_dict
+
+
+def _pre_load_state_dict_hook(
+    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
+) -> None:
+    replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
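
These hooks lean on replace_by_prefix_ from fairscale.utils.state_dict; the following self-contained stand-in (the name replace_by_prefix_stub and the sample keys are invented here) mirrors how the helper is called above:

    from typing import Dict

    import torch


    def replace_by_prefix_stub(state_dict: Dict[str, torch.Tensor], old_prefix: str, new_prefix: str) -> None:
        """Rename every key starting with old_prefix so it starts with new_prefix instead."""
        for key in list(state_dict.keys()):
            if key.startswith(old_prefix):
                state_dict[new_prefix + key[len(old_prefix):]] = state_dict.pop(key)


    sd = {"_fsdp_wrapped_module.weight": torch.zeros(2, 2), "_fsdp_wrapped_module.bias": torch.zeros(2)}
    replace_by_prefix_stub(sd, "_fsdp_wrapped_module.", "")  # what _post_state_dict_hook does with prefix=""
    assert sorted(sd) == ["bias", "weight"]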

@@ -415,7 +415,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
         shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
         assert shape_dtype_device_match
         return True
-    except AssertionError as e:
+    except (AssertionError, RuntimeError) as e:
         if raise_exception:
             raise e
         else:
...

@@ -434,11 +434,9 @@ class TestSaveLoadStateDict(DistributedTest):
         ddp_model.state_dict()
         ddp_model.state_dict()  # second call

-    @parameterized.expand([[False], [True]], name_func=rename_test)
-    def test_state_dict_after_forward_mixed_precision(self, mixed_precision):
-        test_fn = functools.partial(
-            self._test_module_state_dict, {"flatten_parameters": False, "mixed_precision": mixed_precision}
-        )
+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_state_dict_after_forward(self, config):
+        test_fn = functools.partial(self._test_module_state_dict, config)
         spawn_and_init(test_fn)

     @parameterized.expand([[False], [True]], name_func=rename_test)

@@ -474,6 +472,62 @@ class TestSaveLoadStateDict(DistributedTest):
             except Exception:
                 pass

+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_nested_wrapped_model(self, config):
+        if config["mixed_precision"]:
+            return  # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
+        test_fn = functools.partial(self._test_nested_wrapped_model, config=config)
+        spawn_and_init(test_fn)
+
+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_nested_wrapped_model_local_state_dict(self, config):
+        if config["mixed_precision"]:
+            return  # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
+        test_fn = functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config)
+        spawn_and_init(test_fn)
+
+    @classmethod
+    def _test_nested_wrapped_model(cls, rank, group, config=None):
+        # Get reference state dict without any nested FSDP instances.
+        model = NestedWrappedModule(group, None).cuda()
+        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
+        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
+        ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}
+
+        # Create a nested FSDP-wrapped instance.
+        model = NestedWrappedModule(group, config)
+        model = FullyShardedDataParallel(model, group, **config).cuda()
+        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
+
+        # Round-trip state dict save/load/save.
+        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
+        model.load_state_dict(state_dict)
+        state_dict = model.state_dict()
+
+        assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
+        for key in ref_state_dict.keys():
+            assert objects_are_equal(
+                ref_state_dict[key], state_dict[key], raise_exception=False
+            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
+
+    @classmethod
+    def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
+        # Create a nested FSDP-wrapped instance.
+        model = NestedWrappedModule(group, config)
+        model = FullyShardedDataParallel(model, group, **config).cuda()
+        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
+
+        # Round trip state dict save/load/save.
+        ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
+        model.load_local_state_dict(ref_state_dict)
+        state_dict = model.local_state_dict()
+
+        assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
+        for key in ref_state_dict.keys():
+            assert objects_are_equal(
+                ref_state_dict[key], state_dict[key], raise_exception=False
+            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
+
+
 class TestHooks(DistributedTest):
     # Feel free to modify these tests as the implementation changes.

@@ -689,7 +743,10 @@ class NestedWrappedModule(nn.Module):
         torch.manual_seed(0)  # keep everything deterministic
         self.module = nn.Sequential(
-            nn.Linear(8, 4), _maybe_wrap(nn.Linear(4, 16)), _maybe_wrap(nn.Linear(16, 4)), nn.Linear(4, 8),
+            nn.Linear(8, 4),
+            _maybe_wrap(nn.Sequential(_maybe_wrap(nn.Linear(4, 16)), nn.Linear(16, 16),)),
+            _maybe_wrap(nn.Linear(16, 4)),
+            nn.Linear(4, 8),
         )

     def get_input(self, device):
...