Unverified Commit 57079b08 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

Aligning OSS state dict with...

Aligning OSS state dict with `https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer` (#31)
parent d9e6ceaa
......@@ -127,7 +127,7 @@ class OSS(Optimizer):
len(self._all_states) > 0
), "The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"
return {"states": self._all_states}
return {"state": self._all_states, "param_groups": self.param_groups}
def load_local_state_dict(self, state_dict: dict) -> None:
""" Loads this rank's state_dict. """
......@@ -140,7 +140,7 @@ class OSS(Optimizer):
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
""" Loads this rank's optimizer state_dict, given the global optimizer state. """
# Dispatch this rank's state dictionary to the local load
self.load_local_state_dict(state_dict["states"][self.rank])
self.load_local_state_dict(state_dict["state"][self.rank])
def add_param_group(self, param_group: dict) -> None:
super().add_param_group(param_group)
......
......@@ -222,7 +222,7 @@ def run_test_collect_shards(rank, world_size, reference_rank):
# - load it again
if rank == reference_rank:
optimizer_state_dict = optimizer.state_dict()
assert len(optimizer_state_dict["states"]) == world_size
assert len(optimizer_state_dict["state"]) == world_size
else:
optimizer_state_dict = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment