Unverified commit ab32cb7d authored by msbaines, committed by GitHub

[fix] optim/oss: work correctly with LRScheduler (#58)

* [fix] optim/oss: work correctly with LRScheduler

Sync lr before every step and before consolidate.
parent 8c8eb8e8
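
For context: a PyTorch LR scheduler mutates the lr entries of the optimizer it is attached to, but OSS keeps a separate wrapped shard optimizer in self.optim with its own param_groups. A scheduler attached to the OSS wrapper therefore never reached the optimizer that actually steps, which is what this commit fixes. A minimal single-process sketch of the behaviour (the import path and the gloo process-group setup are assumptions for illustration; the tests in this diff presumably set up the process group in a fixture):

import os
import torch
import torch.distributed as dist
from torch.optim.lr_scheduler import StepLR
from fairscale.optim import OSS  # assumed import path; the tests below use optim.OSS

# Single-process "distributed" setup so OSS can be constructed locally (an
# assumption made only so this sketch is runnable on its own).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

x = torch.tensor([1.0], requires_grad=True)
o = OSS([x], lr=0.01)        # wrapper exposing its own param_groups
s = StepLR(o, step_size=1)   # scheduler attached to the wrapper, not to o.optim

x.backward()
o.step()
s.step()                     # decays o.param_groups[0]["lr"] from 0.01 to 0.001

x.backward()
o.step()                     # with this fix, _sync_lr() copies 0.001 into o.optim
                             # before stepping; previously the wrapped shard
                             # optimizer kept using the stale 0.01 here
print(o.param_groups[0]["lr"], o.optim.param_groups[0]["lr"])  # 0.001 0.001 with the fix

dist.destroy_process_group()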
@@ -94,6 +94,9 @@ class OSS(Optimizer):
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
# For example, the apex library contains fused optimizers with a step that supports extra kwargs.
def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
# Sync lr in case it's been updated by an LRScheduler.
self._sync_lr()
# Run the optimizer step on this shard only
loss = self.optim.step(closure=closure, **kwargs) # type: ignore
@@ -113,6 +116,9 @@ class OSS(Optimizer):
This needs to be called on all replicas """
# Sync lr in case it's been updated by an LRScheduler.
self._sync_lr()
if self.rank == recipient_rank:
# Pull the sharded state from all the other replicas
# Store all the states in order, rank by rank
@@ -154,17 +160,17 @@ class OSS(Optimizer):
param = id_map[k]
self.optim.state[param] = recursive_copy_to_device(v, non_blocking=True, device=param.device)
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
""" Restore the global parameter groups as well as the shard """
# Dispatch this rank's state dictionary to the wrapped shard optimizer
self.load_local_state_dict(state_dict["state"][self.rank])
# Restore the global param_groups (the params themselves are already correct)
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for global_group, local_group in zip(self.param_groups, groups):
for k, v in local_group.items():
if k != "params":
global_group[k] = v
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
""" Restore the global parameter groups as well as the shard """
# Dispatch this rank's state dictionary to the wrapped shard optimizer
self.load_local_state_dict(state_dict["state"][self.rank])
def add_param_group(self, param_group: dict) -> None:
super().add_param_group(param_group)
if not self.in_super_constructor:
@@ -172,6 +178,11 @@ class OSS(Optimizer):
if len(param_groups) == len(self.optim.param_groups) + 1:
self.optim.add_param_group(param_groups[-1])
def _sync_lr(self) -> None:
"""Sync learning rate (needed to support LRScheduler)."""
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
local_group["lr"] = global_group["lr"]
def _collect_sharded_states(self) -> List[Dict[str, Any]]:
"""
Collect all the state shards, in CPU memory.
......
@@ -74,6 +74,25 @@ def test_state_dict():
assert o.param_groups[0]["params"][0].device == x.device
def test_lr_scheduler():
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
x2 = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.01)
o2 = torch.optim.SGD([x2], lr=0.01)
s = torch.optim.lr_scheduler.StepLR(o, 1)
s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
for _ in range(5):
x.backward()
o.zero_grad()
o.step()
s.step()
x2.backward()
o2.zero_grad()
o2.step()
s2.step()
assert x == x2
class SGDWithStepKWArg(torch.optim.SGD):
def step(self, closure=None, kwarg=[]):
super().step()
@@ -97,6 +116,8 @@ def test_local_state_dict():
o = optim.OSS([x], lr=0.01)
o.load_local_state_dict(local_state_dict)
# We should now be using a lr of 0.1.
assert o.optim.param_groups[0]["lr"] == 0.1
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
......
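
Likewise, consolidate_state_dict() now calls _sync_lr() first, so the consolidated checkpoint reflects the scheduler-adjusted lr; as the docstring notes, it has to be called on all replicas while only the recipient rank gathers the shards. A hedged sketch of the intended call pattern (train_one_epoch, model and the recipient rank of 0 are illustrative assumptions, not part of this diff; o and s are the OSS optimizer and scheduler from the earlier sketch):

# Assumes the distributed setup from the earlier sketch; every rank runs this loop.
for epoch in range(10):
    train_one_epoch(model, o)  # hypothetical helper: backward() + o.step() per batch
    s.step()                   # scheduler updates o.param_groups only

    # All replicas must call this; _sync_lr() runs first, then the shards are
    # gathered on the recipient rank (assumed to be rank 0 here).
    o.consolidate_state_dict()
    if dist.get_rank() == 0:
        # state_dict() is assumed to return the consolidated state on this rank.
        torch.save(o.state_dict(), "oss_checkpoint.pt")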