"docs/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "fb07c307a1389a6a92a6e2e21dc243418f92ed82"
Unverified commit 220ee323 authored by msbaines, committed by GitHub
Browse files

[fix] optim/oss: PyTorch already handles putting state on proper device (#54)

parent 09028a0d
...@@ -139,10 +139,7 @@ class OSS(Optimizer): ...@@ -139,10 +139,7 @@ class OSS(Optimizer):
def load_local_state_dict(self, state_dict: dict) -> None: def load_local_state_dict(self, state_dict: dict) -> None:
""" Loads this rank's state_dict. """ """ Loads this rank's state_dict. """
# Make sure that the state is on the appropriate device self.optim.load_state_dict(state_dict)
state_dict_ondevice = recursive_copy_to_device(state_dict, non_blocking=False, device=self._device)
self.optim.load_state_dict(state_dict_ondevice)
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
""" Restore the global parameter groups as well as the shard """ """ Restore the global parameter groups as well as the shard """
......
...@@ -41,7 +41,12 @@ def test_create(): ...@@ -41,7 +41,12 @@ def test_create():
def test_state_dict(): def test_state_dict():
x = torch.tensor([1.0], device=DEVICE, requires_grad=True) x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1) o = optim.OSS([x], lr=0.1, momentum=0.9)
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
o.zero_grad()
o.consolidate_state_dict() # Sync state dict in between replicas - even if there are none o.consolidate_state_dict() # Sync state dict in between replicas - even if there are none
state_dict = o.state_dict() state_dict = o.state_dict()
...@@ -54,13 +59,16 @@ def test_state_dict(): ...@@ -54,13 +59,16 @@ def test_state_dict():
# Check that it's correctly loaded # Check that it's correctly loaded
o = optim.OSS([x], lr=0.01) o = optim.OSS([x], lr=0.01)
o.load_state_dict(state_dict) o.load_state_dict(state_dict)
# Check that state is correct and on proper device
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
# We should now be using a lr of 0.1, both within the optimizer # We should now be using a lr of 0.1, both within the optimizer
# and as exposed by the .param_groups attribute # and as exposed by the .param_groups attribute
assert o.param_groups[0]["lr"] == 0.1 assert o.param_groups[0]["lr"] == 0.1
x.backward() x.backward()
o.step() o.step()
assert x == torch.tensor([0.9], device=DEVICE) assert x == torch.tensor([0.71], device=DEVICE)
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.9], device=DEVICE)
# Check that the exposed param_groups are on the proper device # Check that the exposed param_groups are on the proper device
assert o.param_groups[0]["params"][0].device == x.device assert o.param_groups[0]["params"][0].device == x.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment