Unverified Commit 0d1f058b authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[feat] oss: add rank_local_state_dict staticmethod (#174)

parent b5ccedc0
...@@ -267,6 +267,18 @@ class OSS(Optimizer): ...@@ -267,6 +267,18 @@ class OSS(Optimizer):
"local_state_dict": False, "local_state_dict": False,
} }
@staticmethod
def rank_local_state_dict(rank: int, state_dict: dict) -> dict:
"""Returns the local_state_dict for a given rank.
Arguments:
rank (int): rank to get local_state_dict for
state_dict (dict): global state_dict
"""
# Get this optimizer's param_groups shard
param_groups = state_dict["param_groups"][state_dict["partition"][rank][0] : state_dict["partition"][rank][1]]
return {"state": state_dict["state"][rank], "param_groups": param_groups}
def load_local_state_dict(self, state_dict: dict) -> None: def load_local_state_dict(self, state_dict: dict) -> None:
"""Loads this rank's state_dict. """Loads this rank's state_dict.
...@@ -306,12 +318,8 @@ class OSS(Optimizer): ...@@ -306,12 +318,8 @@ class OSS(Optimizer):
if state_dict["local_state_dict"]: if state_dict["local_state_dict"]:
self.load_local_state_dict(state_dict) self.load_local_state_dict(state_dict)
else: else:
# Get this optimizer's param_groups shard
param_groups = state_dict["param_groups"][
state_dict["partition"][self.rank][0] : state_dict["partition"][self.rank][1]
]
# Dispatch this rank's state dictionary to the wrapped shard optimizer # Dispatch this rank's state dictionary to the wrapped shard optimizer
self.load_local_state_dict({"state": state_dict["state"][self.rank], "param_groups": param_groups}) self.load_local_state_dict(OSS.rank_local_state_dict(self.rank, state_dict))
def add_param_group(self, param_group: dict) -> None: def add_param_group(self, param_group: dict) -> None:
"""Add a param group to the :class:`Optimizer` s `param_groups`. """Add a param group to the :class:`Optimizer` s `param_groups`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment