Unverified Commit d217278c authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[feat][OSS] handle the device being changed after construction (#523)

parent 2d2412e2
......@@ -220,6 +220,12 @@ class OSS(Optimizer):
# Sync oss param_groups attributes in case they've been updated by a scheduler.
OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
# Catch a possible change of devices in between OSS construction and step()
if self._default_device.type != self.param_groups[0]["params"][0].device.type:
logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
self._clear_cache()
self.refresh_trainable()
# Run the optimizer step on this shard only:
if closure is not None:
loss = self.optim.step(closure=closure, **kwargs) # type: ignore
......@@ -591,3 +597,9 @@ class OSS(Optimizer):
else:
# This rank has an empty shard, that's fine
self.buckets[device].append(torch.zeros(0, device=device))
# Clear the buffer keys which are not in use anymore (could be that the devices changed)
devices_in_use = list(self.per_device_params.keys())
devices_to_pop = list(filter(lambda x: x not in devices_in_use, self.buckets.keys()))
for d in devices_to_pop:
self.buckets.pop(d)
......@@ -154,6 +154,18 @@ class TestSingleRank(unittest.TestCase):
assert kwarg == [5]
assert x == torch.tensor([0.9], device=DEVICE)
@skip_if_no_cuda
def test_device_change(self):
    """Regression test: OSS must survive its parameters moving to another
    device after construction (it re-allocates its buffers on step())."""
    # Build the model on CPU so OSS records CPU as the parameters' device.
    x = torch.nn.Linear(1, 1).to("cpu")
    o = optim.OSS(x.parameters(), torch.optim.SGD, lr=0.1)
    # Move the model to device after OSS was constructed
    # (nn.Module.to moves parameters in place, so OSS still references them).
    x.to(DEVICE)
    # Produce gradients on the new device so step() has something to consume.
    x(torch.zeros((1), device=DEVICE)).backward()
    # Check that OSS detects that the device changed
    # (should rebuild its internal buffers rather than raise).
    o.step()
def test_step_with_extra_inner_key(self):
class SGDWithNewKey(torch.optim.SGD):
# Dummy optimizer which adds a new key to the param groups
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment