Unverified Commit 88553373 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[fix] OSS - enforce cuda parameters for state consolidation if NCCL backend (#573)

parent 04001e76
......@@ -326,7 +326,10 @@ class OSS(Optimizer):
self._all_states = []
should_collect_state = self.rank == recipient_rank or recipient_rank == -1
should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1
should_send_state = self.rank != recipient_rank
# NCCL requires CUDA tensors for all communication primitives
dist_device = torch.device("cuda") if self.backend == dist.Backend.NCCL else self._default_device
for rank in range(self.world_size):
if rank == self.rank:
......@@ -340,18 +343,18 @@ class OSS(Optimizer):
state_to_share = (
self.optim.state_dict()
if should_send_state
else torch.tensor([0], dtype=torch.uint8, device=self._default_device)
else torch.tensor([0], dtype=torch.uint8, device=dist_device)
)
broadcast_object(
state_to_share, src_rank=self.global_rank, group=self.group, dist_device=self._default_device,
state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device,
)
else:
# Fetch the optim state from the other replicas
replica_state = broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._default_device),
torch.tensor([0], dtype=torch.uint8, device=dist_device),
src_rank=self._local_to_global_rank[rank],
group=self.group,
dist_device=self._default_device,
dist_device=dist_device,
)
if should_collect_state:
......
......@@ -470,6 +470,11 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
_ = optimizer.step(closure=closure)
check_same_models_across_ranks(model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=False)
# Check that if the model is moved to cpu, the optimizer consolidation still works
model.cpu()
optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99)
optimizer.consolidate_state_dict(recipient_rank=reference_rank)
dist.destroy_process_group()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment