Unverified Commit 3d7f524a authored by Benjamin Lefaudeux, committed by GitHub

[feat] OSS: optional closure argument for the optimizer (#86)

Make OSS compatible with optimizers that do not support the closure argument.
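
For context, a minimal sketch of what this enables, modeled on the new test added below. The import path, the single-process gloo group, and SGDWithoutClosure are illustrative assumptions, not part of the patch:

import os
import torch
import torch.distributed as dist
from fairscale import optim  # assumed import path

# OSS shards optimizer state across ranks and broadcasts it on step(),
# so a process group must exist even for a single process.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

class SGDWithoutClosure(torch.optim.SGD):
    # An optimizer whose step() signature takes no closure argument.
    def step(self):
        return super().step()

x = torch.tensor([1.0], requires_grad=True)
o = optim.OSS([x], SGDWithoutClosure, lr=0.1)
x.backward()  # d(x)/dx == 1, so one SGD step moves x by -lr
o.step()      # before this patch, OSS always forwarded closure=..., raising TypeError here
assert x == torch.tensor([0.9])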
parent 6851247a
@@ -103,8 +103,11 @@ class OSS(Optimizer):
         # Sync oss param_groups attributes in case they've been updated by a scheduler.
         self._sync_param_groups()

-        # Run the optimizer step on this shard only
-        loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
+        # Run the optimizer step on this shard only:
+        if closure is not None:
+            loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
+        else:
+            loss = self.optim.step(**kwargs)

         # Sync all the states. Broadcast requests are issued async, we check completeness before moving on
         requests = []
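The closure path is unchanged by the branch above: when a closure is passed, it is forwarded to the wrapped optimizer and its loss is returned. A sketch of that path, reusing the hypothetical setup from the example above:

def closure():
    # Standard PyTorch closure: re-evaluate the loss and backprop.
    o.zero_grad()
    loss = (x - 2.0).pow(2).sum()
    loss.backward()
    return loss

loss = o.step(closure=closure)  # takes the closure-is-not-None branch
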
@@ -103,13 +103,12 @@ def test_lr_scheduler():
     assert x == x2


-class SGDWithStepKWArg(torch.optim.SGD):
-    def step(self, closure=None, kwarg=[]):
-        super().step()
-        kwarg.append(5)
-
-
 def test_step_with_kwargs():
+    class SGDWithStepKWArg(torch.optim.SGD):
+        def step(self, closure=None, kwarg=[]):
+            super().step()
+            kwarg.append(5)
+
     kwarg = []
     x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
     o = optim.OSS([x], SGDWithStepKWArg, lr=0.1)
@@ -119,6 +118,18 @@ def test_step_with_kwargs():
     assert x == torch.tensor([0.9], device=DEVICE)


+def test_step_without_closure():
+    class SGDWithoutClosure(torch.optim.SGD):
+        def step(self):
+            return super().step()
+
+    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+    o = optim.OSS([x], SGDWithoutClosure, lr=0.1)
+    x.backward()
+    o.step()
+    assert x == torch.tensor([0.9], device=DEVICE)
+
+
 def test_local_state_dict():
     x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
     o = optim.OSS([x], lr=0.1)