Unverified Commit d16e9f61 authored by Benjamin Lefaudeux, committed by GitHub

[feat] Gracefully handle local/global state dict queries (#89)

Return either the local or the global optimizer state when queried, depending on whether a consolidation was run beforehand.
parent 3d7f524a
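
For orientation, a minimal sketch of how the changed save path is meant to be used. It assumes a torch.distributed process group is already initialized, that the wrapper is importable as fairscale.optim.OSS with its default SGD backend, and the model and hyperparameters below are placeholders:

    import torch
    from fairscale.optim import OSS

    model = torch.nn.Linear(4, 4)
    optimizer = OSS(model.parameters(), lr=0.1)

    # Global checkpoint: gather every rank's shard first, then query the state
    # from the replica that performed the consolidation.
    optimizer.consolidate_state_dict()
    global_state = optimizer.state_dict()  # carries "local_state_dict": False

    # Without a prior consolidation, state_dict() no longer asserts: it warns and
    # returns this rank's shard only, tagged with "local_state_dict": True.
    fresh_optimizer = OSS(model.parameters(), lr=0.1)
    local_state = fresh_optimizer.state_dict()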
@@ -144,13 +144,18 @@ class OSS(Optimizer):
         """
         Return the last known global optimizer state, which consist of a list of the shards.

-        NOTE: This is limited to the replica which was responsible for the consolidation.
+        NOTE:
+        - If the state has not been consolidated, this returns a shard's worth, not the global state.
+        - Returning the global state is limited to the replica which was responsible for the consolidation.
         The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
         """

-        assert (
-            len(self._all_states) > 0
-        ), "The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"
+        if len(self._all_states) == 0:
+            logging.warning("Optimizer state has not been consolidated. Returning the local state")
+            logging.warning("Please call `consolidate_state_dict()` beforehand if you meant to save the global state")
+            state_dict = self.local_state_dict()
+            state_dict["local_state_dict"] = True
+            return state_dict

         # Flatten the param_groups, save the partition which logs the rank <> shard correspondence
         partition: List[Tuple[int, int]] = []
@@ -167,6 +172,7 @@ class OSS(Optimizer):
             "state": [s["state"] for s in self._all_states],
             "param_groups": param_groups,
             "partition": partition,
+            "local_state_dict": False,
         }

     def load_local_state_dict(self, state_dict: dict) -> None:
@@ -196,12 +202,16 @@ class OSS(Optimizer):
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """ Restore the global parameter groups as well as the shard """
-        # Get this optimizer's param_groups shard
-        param_groups = state_dict["param_groups"][
-            state_dict["partition"][self.rank][0] : state_dict["partition"][self.rank][1]
-        ]
-        # Dispatch this rank's state dictionary to the wrapped shard optimizer
-        self.load_local_state_dict({"state": state_dict["state"][self.rank], "param_groups": param_groups})
+        # Check whether we got a local or global dict
+        if state_dict["local_state_dict"]:
+            self.load_local_state_dict(state_dict)
+        else:
+            # Get this optimizer's param_groups shard
+            param_groups = state_dict["param_groups"][
+                state_dict["partition"][self.rank][0] : state_dict["partition"][self.rank][1]
+            ]
+            # Dispatch this rank's state dictionary to the wrapped shard optimizer
+            self.load_local_state_dict({"state": state_dict["state"][self.rank], "param_groups": param_groups})

     def add_param_group(self, param_group: dict) -> None:
         super().add_param_group(param_group)
...
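Combined with the load_state_dict() dispatch above, a checkpoint round trip no longer requires the caller to track whether the state was consolidated. A short, illustrative continuation of the sketch above (the file name is hypothetical):

    # Save whatever state_dict() produced; the "local_state_dict" flag travels with it.
    torch.save(optimizer.state_dict(), "oss_checkpoint.pt")

    # On restore, load_state_dict() checks the flag: a local dict is forwarded to
    # load_local_state_dict(), a global one is first sliced down to this rank's shard.
    restored = OSS(model.parameters(), lr=0.1)
    restored.load_state_dict(torch.load("oss_checkpoint.pt"))
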
@@ -144,6 +144,20 @@ def test_local_state_dict():
     assert x == torch.tensor([0.9], device=DEVICE)


+def test_implicit_local_state_dict():
+    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+    o = optim.OSS([x], lr=0.1)
+    local_state_dict = o.state_dict()
+    o = optim.OSS([x], lr=0.01)
+    o.load_state_dict(local_state_dict)
+    # We should now be using a lr of 0.1.
+    assert o.optim.param_groups[0]["lr"] == 0.1
+    assert o.param_groups[0]["lr"] == 0.1
+    x.backward()
+    o.step()
+    assert x == torch.tensor([0.9], device=DEVICE)
+
+
 def run_test_add_param_group(rank, world_size):
     dist_init(rank, world_size)
     params = []
...