Unverified Commit 585f177b authored by Benjamin Lefaudeux, committed by GitHub

[fix] Properly restore a sharded optim state (#39)



* hotfix a half-cooked optimizer state restoration: the global shared state also needs to be restored

* [cleanup] get 100% coverage on oss.py (#38)
Co-authored-by: Mandeep Singh Baines <msb@fb.com>

* better unit testing, check that the .param_groups attribute is properly in sync with the loaded state
Co-authored-by: msbaines <35972327+msbaines@users.noreply.github.com>
parent 3427a039
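For context, the failure this commit fixes shows up in a plain save/load round trip: before the change, OSS.load_state_dict() only restored the wrapped shard optimizer, so the exposed .param_groups (e.g. the learning rate) kept their constructor values. Below is a minimal sketch, not part of the commit itself, assuming a single-process gloo group and that the class is importable as fairscale.optim.OSS; the helper name run_roundtrip is made up for illustration, and the flow mirrors the unit test further down.

import torch
import torch.distributed as dist
from fairscale.optim import OSS


def run_roundtrip() -> None:
    # Single-process group so that dist.get_rank() / broadcast calls inside OSS work.
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

    x = torch.tensor([1.0], requires_grad=True)
    o = OSS([x], lr=0.1)

    o.consolidate_state_dict()  # gather the sharded state (a no-op with a single rank)
    saved = o.state_dict()

    o = OSS([x], lr=0.01)  # fresh optimizer with a different lr
    o.load_state_dict(saved)

    # With the fix, the restored optimizer exposes the checkpointed lr again.
    assert o.param_groups[0]["lr"] == 0.1

    dist.destroy_process_group()


if __name__ == "__main__":
    run_roundtrip()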
@@ -49,10 +49,12 @@ class OSS(Optimizer):
     in_super_constructor: bool
 
     def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any = dist.group.WORLD, **defaults: Any):
+        # Hold all the model params in the root .param_groups
         self.in_super_constructor = True
         super().__init__(params, defaults)
         self.in_super_constructor = False
 
+        # Build the wrapped optimizer, responsible for a shard of the params
         self.group = group
         self.rank = dist.get_rank(group)
         param_groups = self.partition_parameters()
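The partition_parameters() call above is what decides which rank owns which parameters. The helper below is purely illustrative and is not fairscale's implementation: a hedged sketch of one common scheme, greedy size-balanced assignment, where each parameter goes to the currently smallest shard.

from typing import List
import torch


def greedy_partition(params: List[torch.Tensor], world_size: int) -> List[List[torch.Tensor]]:
    # One shard (list of tensors) per rank, plus a running element count per shard.
    shards: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
    sizes = [0] * world_size
    # Largest parameters first; each one is assigned to the least-loaded shard so far.
    for p in sorted(params, key=lambda t: t.numel(), reverse=True):
        rank = sizes.index(min(sizes))
        shards[rank].append(p)
        sizes[rank] += p.numel()
    return shards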
@@ -90,11 +92,14 @@ class OSS(Optimizer):
         return param_groups
 
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+        # Run the optimizer step on this shard only
         loss = self.optim.step(closure=closure)
+
+        # Sync all the states
         for rank, param_groups in enumerate(self.partition_parameters()):
             for param_group in param_groups:
                 for param in param_group["params"]:
-                    dist.broadcast(param, rank, group=self.group)
+                    dist.broadcast(tensor=param, src=rank, group=self.group)
         return loss
 
     def local_state_dict(self) -> dict:
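The pattern added to step() is: every rank updates only the shard it owns, then each parameter is broadcast from its owner so all replicas converge to the same weights. A stand-alone sketch of that owner-broadcast sync, assuming torch.distributed is already initialized; the helper name sync_params and the shards_per_rank mapping are illustrative, not OSS internals.

from typing import Dict, List, Optional, Any
import torch
import torch.distributed as dist


def sync_params(shards_per_rank: Dict[int, List[torch.Tensor]], group: Optional[Any] = None) -> None:
    # Each parameter is broadcast from the rank that just updated it; on every
    # other rank the broadcast overwrites the stale local copy in place.
    for owner_rank, params in shards_per_rank.items():
        for param in params:
            dist.broadcast(tensor=param.data, src=owner_rank, group=group)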
@@ -138,10 +143,13 @@ class OSS(Optimizer):
         self.optim.load_state_dict(state_dict_ondevice)
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        """ Loads this rank's optimizer state_dict, given the global optimizer state. """
-        # Dispatch this rank's state dictionary to the local load
+        """ Restore the global parameter groups as well as the shard """
+        # Dispatch this rank's state dictionary to the wrapped shard optimizer
         self.load_local_state_dict(state_dict["state"][self.rank])
+
+        # Restore the global param_groups
+        self.param_groups = state_dict["param_groups"]
 
     def add_param_group(self, param_group: dict) -> None:
         super().add_param_group(param_group)
         if not self.in_super_constructor:
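Restoring .param_groups matters beyond the assertion in the test: anything that reads or mutates the optimizer through its param groups, a learning-rate scheduler for instance, would otherwise keep working against the stale constructor values after a checkpoint restore. A hedged sketch of that interaction, again assuming a single-process gloo group and fairscale.optim.OSS; the surrounding setup is illustrative, not from the commit.

import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import StepLR
from fairscale.optim import OSS

# Assumption: single-process group, as in the round-trip sketch above.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
optimizer = OSS(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)  # reads and writes optimizer.param_groups

# ... train, checkpoint (consolidate_state_dict() + state_dict()), restart ...
# After load_state_dict(), .param_groups carries the checkpointed lr, so the
# scheduler resumes from a consistent value instead of the constructor default.

dist.destroy_process_group()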
@@ -44,9 +44,20 @@ def test_state_dict():
     o = optim.OSS([x], lr=0.1)
     o.consolidate_state_dict()  # Sync state dict in between replicas - even if there are none
     state_dict = o.state_dict()
+
+    # Check that the pulled state is what we expect
+    assert state_dict["param_groups"][0]["lr"] == 0.1
+
+    # Check that the pulled state and the .param_groups attribute are in sync
+    assert state_dict["param_groups"][0]["lr"] == o.param_groups[0]["lr"]
+
+    # Check that it's correctly loaded
     o = optim.OSS([x], lr=0.01)
     o.load_state_dict(state_dict)
-    # We should now be using a lr of 0.1.
+
+    # We should now be using a lr of 0.1, both within the optimizer
+    # and as exposed by the .param_groups attribute
+    assert o.param_groups[0]["lr"] == 0.1
     x.backward()
     o.step()
     assert x == torch.tensor([0.9], device=DEVICE)