Unverified Commit e1f36346 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

capture default device when refreshing the params (#786)

parent e00dfd95
......@@ -462,10 +462,12 @@ class OSS(Optimizer):
of some parameters changed.
"""
# Make sure that we capture the current default device
self._default_device = list(self._per_device_params.keys())[0]
# Create the optim which will work on the param shard
if not hasattr(self, "optim"):
self._clear_cache()
self._default_device = list(self._per_device_params.keys())[0]
self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)
OSS._sync_param_groups(self.optim.param_groups, self.param_groups)
......
......@@ -151,6 +151,9 @@ class TestSingleRank(unittest.TestCase):
# Check that OSS detects that the device changed
o.step()
# Check that the default device has been updated
assert o._default_device.type == DEVICE
def test_step_with_extra_inner_key(self):
class SGDWithNewKey(torch.optim.SGD):
# Dummy optimizer which adds a new key to the param groups
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment