Unverified commit e1f36346, authored by Benjamin Lefaudeux and committed by GitHub
Browse files

capture default device when refreshing the params (#786)

parent e00dfd95
@@ -462,10 +462,12 @@ class OSS(Optimizer):
         of some parameters changed.
         """
+        # Make sure that we capture the current default device
+        self._default_device = list(self._per_device_params.keys())[0]
         # Create the optim which will work on the param shard
         if not hasattr(self, "optim"):
             self._clear_cache()
-            self._default_device = list(self._per_device_params.keys())[0]
             self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)
             OSS._sync_param_groups(self.optim.param_groups, self.param_groups)
......
@@ -151,6 +151,9 @@ class TestSingleRank(unittest.TestCase):
         # Check that OSS detects that the device changed
         o.step()
+        # Check that the default device has been updated
+        assert o._default_device.type == DEVICE
     def test_step_with_extra_inner_key(self):
         class SGDWithNewKey(torch.optim.SGD):
             # Dummy optimizer which adds a new key to the param groups
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment