Unverified Commit 5a268b25 authored by Benjamin Lefaudeux, committed by GitHub

[feat] OSS: Sync all attributes (#67)

Make sure that all attributes (not just the learning rate) are kept in sync between OSS.param_groups and the param_groups of the wrapped optimizer. Some frameworks make it possible to alter any attribute on a scheduled basis, which can be useful depending on the optimizer, so the keys need to be supported generically rather than hard-coding "lr". Until now such adjustments were silently dropped, which is the worst-case failure mode; this change fixes that.
parent 3a203179
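
For context, here is a minimal sketch of the behaviour this change enables. It is an illustrative example rather than the project's own test code, and it assumes a few API details not spelled out in this diff: that OSS is importable as `fairscale.optim.OSS`, defaults to wrapping `torch.optim.SGD`, exposes the wrapped shard optimizer as `.optim`, and that a single-process gloo group is enough for a local run.

```python
# Illustrative sketch only -- assumes fairscale.optim.OSS defaults to torch.optim.SGD,
# exposes the wrapped shard optimizer as `.optim`, and that a single-process gloo
# group is sufficient for a local run.
import os

import torch
import torch.distributed as dist
from fairscale.optim import OSS

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

x = torch.randn(4, 4, requires_grad=True)
oss = OSS([x], lr=0.1, momentum=0.9)

# Training-loop code (a scheduler, for instance) may rewrite any key on
# oss.param_groups, not only "lr".
for group in oss.param_groups:
    group["momentum"] = 0.5

x.sum().backward()
oss.step()  # step() now copies every key except "params" onto the wrapped optimizer

# Before this change only "lr" was propagated; the momentum update was silently lost.
assert oss.optim.param_groups[0]["momentum"] == 0.5
```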
```diff
@@ -67,6 +67,12 @@ class OSS(Optimizer):
         # Current device is set by the parameters allocated to this rank
         self._device = split_param_groups[self.rank][0]["params"][0].device
 
+        # Sync local and global param_groups keys
+        for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
+            for k, v in local_group.items():
+                if k != "params":
+                    global_group[k] = v
+
     def partition_parameters(self) -> List[List[dict]]:
         """Partitions parameters across distributed ranks.
 
@@ -94,8 +100,8 @@ class OSS(Optimizer):
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
-        # Sync lr in case its been update by an LRScheduler.
-        self._sync_lr()
+        # Sync oss param_groups attributes in case they've been updated by a scheduler.
+        self._sync_param_groups()
 
         # Run the optimizer step on this shard only
         loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
 
@@ -116,8 +122,8 @@ class OSS(Optimizer):
        This needs to be called on all replicas """
-        # Sync lr in case its been update by an LRScheduler.
-        self._sync_lr()
+        # Sync lr and other attributes in case its been updated
+        self._sync_param_groups()
 
        if self.rank == recipient_rank:
            # Pull the sharded state from all the other replicas
 
@@ -176,9 +182,6 @@ class OSS(Optimizer):
            {"state": state_dict["state"][self.rank], "param_groups": state_dict["param_groups"][self.rank]}
        )
 
-        # Update the param_groups attribute for this instance
-        # TODO(ben)
-
    def add_param_group(self, param_group: dict) -> None:
        super().add_param_group(param_group)
        if not self.in_super_constructor:
 
@@ -186,10 +189,13 @@ class OSS(Optimizer):
            if len(param_groups) == len(self.optim.param_groups) + 1:
                self.optim.add_param_group(param_groups[-1])
 
-    def _sync_lr(self) -> None:
-        """Sync learning rate (needed to support LRScheduler)."""
+    def _sync_param_groups(self) -> None:
+        """Sync learning rate and other optimizer attributes (needed to support schedulers)."""
        for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
-            local_group["lr"] = global_group["lr"]
+            for k in local_group.keys():
+                if k != "params":
+                    # Params have been sharded and should not be synced here
+                    local_group[k] = global_group[k]
 
    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
        """
```
```diff
@@ -54,11 +54,17 @@ def test_state_dict():
     assert "param_groups" in state_dict.keys()
     assert "state" in state_dict.keys()
 
-    # Check that the pulled state is what we expect
+    # Check that the pulled state is what we expect, and that we have all the expected keys
     assert state_dict["param_groups"][0][0]["lr"] == 0.1
+    assert state_dict["param_groups"][0][0]["momentum"] == 0.9
+    assert not state_dict["param_groups"][0][0]["nesterov"]
+    assert state_dict["param_groups"][0][0]["weight_decay"] == 0.0
+    assert state_dict["param_groups"][0][0]["dampening"] == 0.0
 
     # Check that the pulled state and the .param_groups attribute are in sync
-    assert state_dict["param_groups"][0][0]["lr"] == o.param_groups[0]["lr"]
+    for k in state_dict["param_groups"][0][0].keys():
+        if k != "params":
+            assert state_dict["param_groups"][0][0][k] == o.param_groups[0][k]
 
     # Check that it's correctly loaded
     o = optim.OSS([x], lr=0.01)
```
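
The same mechanism covers the original LR-scheduler use case. Continuing the sketch above (the StepLR settings are arbitrary, chosen only for illustration):

```python
# Continuing the sketch above; the scheduler only touches oss.param_groups,
# and the next step() forwards the new value to the wrapped shard optimizer.
scheduler = torch.optim.lr_scheduler.StepLR(oss, step_size=1, gamma=0.1)

x.grad = None
x.sum().backward()
oss.step()        # syncs every non-"params" key before stepping the wrapped SGD
scheduler.step()  # rewrites "lr" on oss.param_groups only

x.grad = None
x.sum().backward()
oss.step()        # the next step picks up the scheduled learning rate

assert oss.optim.param_groups[0]["lr"] == oss.param_groups[0]["lr"]
dist.destroy_process_group()
```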