Unverified commit ab32cb7d, authored by msbaines, committed by GitHub

[fix] optim/oss: work correctly with LRScheduler (#58)

* [fix] optim/oss: work correctly with LRScheduler

Sync lr before every step and before consolidate.
parent 8c8eb8e8
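
Why the sync is needed, in a nutshell: a torch.optim.lr_scheduler only rewrites the param_groups of the optimizer instance it was constructed with. OSS is a wrapper that keeps a second, inner optimizer for its shard of the parameters, so a scheduler attached to the OSS instance never touches the inner optimizer's lr. The sketch below uses a hypothetical toy wrapper (not the fairscale implementation) to show the same staleness and the same fix of copying the lr inward before every step:

import torch

class WrappedSGD(torch.optim.Optimizer):
    """Toy stand-in for OSS: exposes its own param_groups (what a scheduler sees)
    and delegates the actual update to an inner optimizer."""

    def __init__(self, params, lr):
        params = list(params)
        super().__init__(params, defaults={"lr": lr})
        self.optim = torch.optim.SGD(params, lr=lr)  # inner "shard" optimizer

    def step(self, closure=None):
        # Copy the (possibly scheduler-updated) lr into the inner optimizer,
        # mirroring what the new _sync_lr() does.
        for outer, inner in zip(self.param_groups, self.optim.param_groups):
            inner["lr"] = outer["lr"]
        return self.optim.step(closure=closure)

x = torch.tensor([1.0], requires_grad=True)
opt = WrappedSGD([x], lr=0.01)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1)  # decays lr by 10x each epoch

x.backward()
opt.step()
sched.step()  # updates opt.param_groups only; the inner optimizer is now stale
assert opt.optim.param_groups[0]["lr"] != opt.param_groups[0]["lr"]
opt.step()    # the copy in step() re-syncs the inner lr before stepping
assert opt.optim.param_groups[0]["lr"] == opt.param_groups[0]["lr"]
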
@@ -94,6 +94,9 @@ class OSS(Optimizer):
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
+        # Sync lr in case it's been updated by an LRScheduler.
+        self._sync_lr()
+
         # Run the optimizer step on this shard only
         loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
@@ -113,6 +116,9 @@ class OSS(Optimizer):
         This needs to be called on all replicas """
+        # Sync lr in case it's been updated by an LRScheduler.
+        self._sync_lr()
+
         if self.rank == recipient_rank:
             # Pull the sharded state from all the other replicas
             # Store all the states in order, rank by rank
@@ -154,17 +160,17 @@ class OSS(Optimizer):
                 param = id_map[k]
                 self.optim.state[param] = recursive_copy_to_device(v, non_blocking=True, device=param.device)
 
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        """ Restore the global parameter groups as well as the shard """
-        # Dispatch this rank's state dictionary to the wrapped shard optimizer
-        self.load_local_state_dict(state_dict["state"][self.rank])
-
         # Restore the global param_groups (the params themselves are already correct)
-        for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
+        for global_group, local_group in zip(self.param_groups, groups):
             for k, v in local_group.items():
                 if k != "params":
                     global_group[k] = v
 
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """ Restore the global parameter groups as well as the shard """
+        # Dispatch this rank's state dictionary to the wrapped shard optimizer
+        self.load_local_state_dict(state_dict["state"][self.rank])
+
     def add_param_group(self, param_group: dict) -> None:
         super().add_param_group(param_group)
         if not self.in_super_constructor:
@@ -172,6 +178,11 @@ class OSS(Optimizer):
         if len(param_groups) == len(self.optim.param_groups) + 1:
             self.optim.add_param_group(param_groups[-1])
 
+    def _sync_lr(self) -> None:
+        """Sync learning rate (needed to support LRScheduler)."""
+        for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
+            local_group["lr"] = global_group["lr"]
+
     def _collect_sharded_states(self) -> List[Dict[str, Any]]:
         """
         Collect all the state shards, in CPU memory.
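With the change above, a stock PyTorch scheduler can be attached straight to the OSS wrapper, which is what the new test below exercises. A minimal usage sketch, assuming the fairscale.optim.OSS import path and a single-process gloo group (OSS needs torch.distributed to be initialized):

import os
import torch
import torch.distributed as dist
from fairscale.optim import OSS  # assumed import path for the wrapper patched above

# OSS shards optimizer state across ranks, so a process group must exist;
# a single-process gloo group is enough for a local smoke test.
if not dist.is_initialized():
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

x = torch.tensor([1.0], requires_grad=True)
optimizer = OSS([x], lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

for _ in range(3):
    x.backward()
    optimizer.step()       # _sync_lr() now runs first, so the shard optimizer sees the decayed lr
    optimizer.zero_grad()
    scheduler.step()
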
@@ -74,6 +74,25 @@ def test_state_dict():
     assert o.param_groups[0]["params"][0].device == x.device
 
 
+def test_lr_scheduler():
+    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+    x2 = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+    o = optim.OSS([x], lr=0.01)
+    o2 = torch.optim.SGD([x2], lr=0.01)
+    s = torch.optim.lr_scheduler.StepLR(o, 1)
+    s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
+    for _ in range(5):
+        x.backward()
+        o.zero_grad()
+        o.step()
+        s.step()
+        x2.backward()
+        o2.zero_grad()
+        o2.step()
+        s2.step()
+    assert x == x2
+
+
 class SGDWithStepKWArg(torch.optim.SGD):
     def step(self, closure=None, kwarg=[]):
         super().step()
@@ -97,6 +116,8 @@ def test_local_state_dict():
     o = optim.OSS([x], lr=0.01)
     o.load_local_state_dict(local_state_dict)
     # We should now be using a lr of 0.1.
+    assert o.optim.param_groups[0]["lr"] == 0.1
+    assert o.param_groups[0]["lr"] == 0.1
     x.backward()
     o.step()
     assert x == torch.tensor([0.9], device=DEVICE)