Unverified Commit c2d6f4b6 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[fix] OSS restore state to proper device (#46)

* move the restored param groups to the original device

* adding a corresponding test
parent 9d6c7b6a
...@@ -148,7 +148,7 @@ class OSS(Optimizer): ...@@ -148,7 +148,7 @@ class OSS(Optimizer):
self.load_local_state_dict(state_dict["state"][self.rank]) self.load_local_state_dict(state_dict["state"][self.rank])
# Restore the global param_groups # Restore the global param_groups
self.param_groups = state_dict["param_groups"] self.param_groups = recursive_copy_to_device(state_dict["param_groups"], non_blocking=True, device=self._device)
def add_param_group(self, param_group: dict) -> None: def add_param_group(self, param_group: dict) -> None:
super().add_param_group(param_group) super().add_param_group(param_group)
......
...@@ -62,6 +62,9 @@ def test_state_dict(): ...@@ -62,6 +62,9 @@ def test_state_dict():
o.step() o.step()
assert x == torch.tensor([0.9], device=DEVICE) assert x == torch.tensor([0.9], device=DEVICE)
# Check that the exposed param_groups are on the proper device
assert o.param_groups[0]["params"][0].device == x.device
def test_local_state_dict(): def test_local_state_dict():
x = torch.tensor([1.0], device=DEVICE, requires_grad=True) x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment