Unverified Commit 09028a0d authored by msbaines, committed by GitHub

[fix] optim/oss: support optimizers with additional step kwargs (#53)

* [fix] optim/oss: support optimizers with additional step kwargs

Some of the optimizers in apex support additional kwargs to step(),
such as scale.
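
As a rough, hedged sketch of what this enables (not part of the commit): a wrapped
optimizer whose step() accepts an extra keyword argument can now receive it through
OSS.step(). SGDWithScale, its scale kwarg, and the single-process gloo bootstrap below
are illustrative assumptions, loosely modelled on the loss-scale argument of apex's
fused optimizers; only the OSS([x], optimizer_class, lr=...) call mirrors the test
added in this commit.

import os

import torch
import torch.distributed as dist
from fairscale.optim import OSS


class SGDWithScale(torch.optim.SGD):
    # Hypothetical optimizer subclass: step() takes an extra `scale` kwarg,
    # loosely mimicking apex's fused optimizers.
    def step(self, closure=None, scale=1.0):
        # Un-scale gradients before the regular SGD update.
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.grad.div_(scale)
        return super().step(closure)


if __name__ == "__main__":
    # OSS shards optimizer state across a process group, so one must exist;
    # a single-process gloo group keeps the sketch self-contained.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    x = torch.ones(4, requires_grad=True)
    opt = OSS([x], SGDWithScale, lr=0.1)
    x.sum().backward()
    # With this fix, the extra kwarg is forwarded to SGDWithScale.step().
    opt.step(scale=2.0)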
parent 5251a69a
 [settings]
-known_third_party =numpy,pytest,setuptools,torch,torchtext
+known_third_party =numpy,pytest,setuptools,torch,torchtext,torchvision
......
@@ -91,9 +91,11 @@ class OSS(Optimizer):
                 param_groups[rank].append(param_group_rank)
         return param_groups
 
-    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+    # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
+    # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
+    def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
         # Run the optimizer step on this shard only
-        loss = self.optim.step(closure=closure)
+        loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
 
         # Sync all the states
         for rank, param_groups in enumerate(self.partition_parameters()):
......
@@ -66,6 +66,22 @@ def test_state_dict():
     assert o.param_groups[0]["params"][0].device == x.device
 
 
+class SGDWithStepKWArg(torch.optim.SGD):
+    def step(self, closure=None, kwarg=[]):
+        super().step()
+        kwarg.append(5)
+
+
+def test_step_with_kwargs():
+    kwarg = []
+    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+    o = optim.OSS([x], SGDWithStepKWArg, lr=0.1)
+    x.backward()
+    o.step(0, kwarg=kwarg)
+    assert kwarg == [5]
+    assert x == torch.tensor([0.9], device=DEVICE)
+
+
 def test_local_state_dict():
     x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
     o = optim.OSS([x], lr=0.1)
......