Unverified Commit 220ee323 authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[fix] optim/oss: PyTorch already handles putting state on proper device (#54)

parent 09028a0d
......@@ -139,10 +139,7 @@ class OSS(Optimizer):
def load_local_state_dict(self, state_dict: dict) -> None:
""" Loads this rank's state_dict. """
# Make sure that the state is on the appropriate device
state_dict_ondevice = recursive_copy_to_device(state_dict, non_blocking=False, device=self._device)
self.optim.load_state_dict(state_dict_ondevice)
self.optim.load_state_dict(state_dict)
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
""" Restore the global parameter groups as well as the shard """
......
......@@ -41,7 +41,12 @@ def test_create():
def test_state_dict():
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
o = optim.OSS([x], lr=0.1, momentum=0.9)
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
o.zero_grad()
o.consolidate_state_dict() # Sync state dict in between replicas - even if there are none
state_dict = o.state_dict()
......@@ -54,13 +59,16 @@ def test_state_dict():
# Check that it's correctly loaded
o = optim.OSS([x], lr=0.01)
o.load_state_dict(state_dict)
# Check that state is correct and on proper device
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
# We should now be using a lr of 0.1, both within the optimizer
# and as exposed by the .param_groups attribute
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
assert x == torch.tensor([0.71], device=DEVICE)
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.9], device=DEVICE)
# Check that the exposed param_groups are on the proper device
assert o.param_groups[0]["params"][0].device == x.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment