Unverified Commit b0c3fe1e authored by Min Xu, committed by GitHub

[minor] add a check around local_state_dict (#1040)


Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 16fba4c0
@@ -962,6 +962,11 @@ class FullyShardedDataParallel(nn.Module):
         so the resulting state_dict can only be loaded after the Module has been
         wrapped with FSDP.
         """
+        # Check state, specifically, we shouldn't be in SUMMON_FULL_PARAMS since
+        # that will produce full state, not sharded state.
+        self.assert_state(
+            [TrainingState.IDLE, TrainingState.FORWARD, TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST]
+        )
         with contextlib.ExitStack() as stack:
             # Tell any nested FSDP instances not to auto summon full params.
             for module in self.modules():  # includes self
@@ -1025,6 +1030,11 @@ class FullyShardedDataParallel(nn.Module):
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
     ) -> NamedTuple:
         """Load a local (sharded) state_dict."""
+        # Check state, specifically, we shouldn't be in SUMMON_FULL_PARAMS since
+        # that will load full state, not sharded state.
+        self.assert_state(
+            [TrainingState.IDLE, TrainingState.FORWARD, TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST]
+        )
         with contextlib.ExitStack() as stack:
             # Tell any nested FSDP instances not to auto summon full params.
             for module in self.modules():  # includes self
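For context, here is a minimal sketch of the misuse these two checks now catch. Everything below is illustrative rather than part of the patch: the toy `Linear` module, the single-process "gloo" process group, and the broad `except` are assumptions; only `local_state_dict()`, `load_local_state_dict()`, and `summon_full_params()` come from the library. Inside `summon_full_params()` the wrapper is in the `SUMMON_FULL_PARAMS` state, so both calls now fail loudly instead of silently saving or loading full (unsharded) state:

```python
# Illustrative sketch only; assumes a fairscale build with this patch applied.
import torch
import torch.distributed as dist
from fairscale.nn import FullyShardedDataParallel as FSDP

# FSDP needs an initialized process group; a single-process "gloo" group is
# enough for a sketch (the address/port here are arbitrary).
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

model = FSDP(torch.nn.Linear(8, 8))

# Normal use: the wrapper is IDLE, so sharded save/load is allowed.
sharded = model.local_state_dict()
model.load_local_state_dict(sharded)

# Misuse: with full params summoned, local_state_dict() would return full
# state and load_local_state_dict() would load into full state. The new
# assert_state calls reject both.
with model.summon_full_params():
    for call in (lambda: model.local_state_dict(),
                 lambda: model.load_local_state_dict(sharded)):
        try:
            call()
        except Exception as err:  # assert_state raises on the wrong TrainingState
            print(f"rejected as expected: {err}")
```

Note that both checks whitelist the states under which sharded save/load is legitimate and leave out only `SUMMON_FULL_PARAMS`.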