Unverified commit 37c686e7, authored by Benjamin Lefaudeux, committed by GitHub

[bugfix] OSS + Apex (#136)

* Fixes the issue w.r.t. Apex; validated with Latte. Classy would need another pass.
parent 6d802f5a
@@ -196,6 +196,9 @@ class OSS(Optimizer):
         ) in self.per_device_params.items():  # all the params on this device (inc all ranks)
             self._broadcast_params(self._broadcast_buffers[device], device_params)
 
+        # Sync hypothetical new results from the wrapped optimizer to the exposed param_groups
+        self._sync_param_groups(local_to_global=True)
+
         return loss
 
     def local_state_dict(self) -> dict:
@@ -334,12 +337,18 @@ class OSS(Optimizer):
         if len(param_groups) == len(self.optim.param_groups) + 1:
             self.optim.add_param_group(param_groups[-1])
 
-    def _sync_param_groups(self) -> None:
-        """Sync learning rate and other optimizer attributes (needed to support schedulers)."""
+    def _sync_param_groups(self, local_to_global: bool = False) -> None:
+        """Sync learning rate and other optimizer attributes (needed to support schedulers).
+
+        If the global param groups have been altered, make sure that the wrapped optimizer
+        uses the up-to-date version. Conversely, if the wrapped optimizer has new keys,
+        expose them through the global param groups."""
         for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
-            for k in local_group.keys():
-                if k != "params":
-                    # Params have been sharded and should not be synced here
-                    local_group[k] = global_group[k]
+            # Sync everything but the parameters
+            for k in filter(lambda x: x != "params", local_group.keys()):
+                if local_to_global:
+                    global_group[k] = local_group[k]
+                elif k in global_group.keys():
+                    local_group[k] = global_group[k]
 
     def _collect_sharded_states(self) -> List[Dict[str, Any]]:
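For readers skimming the diff, here is a minimal, self-contained sketch of the two sync directions that the new local_to_global flag selects, using plain dicts instead of fairscale objects (the helper name sync_param_groups is illustrative only): global-to-local pushes scheduler updates such as a new learning rate down to the wrapped optimizer, while local-to-global surfaces keys that the wrapped optimizer added on its own.

# Standalone sketch of the two sync directions; not fairscale code.
def sync_param_groups(global_groups, local_groups, local_to_global=False):
    """Copy every non-"params" key between the exposed (global) and wrapped (local) groups."""
    for global_group, local_group in zip(global_groups, local_groups):
        for k in filter(lambda x: x != "params", local_group.keys()):
            if local_to_global:
                # Expose keys that the wrapped optimizer created during its step
                global_group[k] = local_group[k]
            elif k in global_group.keys():
                # Push user/scheduler changes (e.g. a new lr) down to the wrapped optimizer
                local_group[k] = global_group[k]

# A scheduler lowered the exposed lr -> propagate it to the wrapped optimizer
exposed = [{"params": [], "lr": 0.01}]
wrapped = [{"params": [], "lr": 0.1}]
sync_param_groups(exposed, wrapped)
assert wrapped[0]["lr"] == 0.01

# The wrapped optimizer grew a new key during step() -> mirror it into the exposed groups
wrapped[0]["new_key"] = 0.5
sync_param_groups(exposed, wrapped, local_to_global=True)
assert exposed[0]["new_key"] == 0.5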
@@ -123,6 +123,21 @@ def test_step_with_kwargs():
     assert x == torch.tensor([0.9], device=DEVICE)
 
 
+def test_step_with_extra_inner_key():
+    class SGDWithNewKey(torch.optim.SGD):
+        # Dummy optimizer which adds a new key to the param groups
+        def step(self, closure=None):
+            super().step()
+            self.param_groups[0]["new_key"] = 0.1
+
+    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+    o = optim.OSS([x], SGDWithNewKey, lr=0.1)
+    x.backward()
+    o.step()
+    assert o.param_groups[0]["new_key"] == 0.1
+    assert x == torch.tensor([0.9], device=DEVICE)
+
+
 def test_step_without_closure():
     class SGDWithoutClosure(torch.optim.SGD):
         def step(self):
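The test above exercises the fix with a dummy inner optimizer; the scenario the commit message points at is Apex amp, whose patched optimizers can end up carrying extra entries in their param_groups. A hedged usage sketch of that combination follows: the amp.initialize / amp.scale_loss calls are standard Apex API, but the exact OSS + Apex interplay shown here is an assumption for illustration, not code from this repository, and it presumes torch.distributed has already been initialized.

# Sketch only: assumes apex is installed, a GPU is available, and a (possibly single-rank)
# torch.distributed process group has already been initialized.
import torch
from apex import amp
from fairscale.optim import OSS

model = torch.nn.Linear(4, 2).cuda()
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

# Apex patches the optimizer in place; it may alter the wrapped optimizer's param_groups.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(8, 4).cuda()).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

# After this step, OSS mirrors any keys the wrapped optimizer added back into optimizer.param_groups
optimizer.step()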