"vscode:/vscode.git/clone" did not exist on "b72b766ee3a33168c68a65802b6e9c57aa5e6ff2"
Unverified commit 88553373 authored by Benjamin Lefaudeux, committed by GitHub

[fix] OSS - enforce cuda parameters for state consolidation if NCCL backend (#573)

parent 04001e76
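
The commit addresses state consolidation when the process group uses the NCCL backend: NCCL collectives only operate on CUDA tensors, so when the model (and therefore the sharded optimizer state) lives on CPU, the broadcasts inside consolidate_state_dict have to be staged on a CUDA device. Below is a minimal sketch of the device-selection logic this commit introduces; the standalone helper name is chosen here purely for illustration.

import torch
import torch.distributed as dist

def consolidation_device(backend: str, default_device: torch.device) -> torch.device:
    # NCCL can only move CUDA tensors, so collectives must be staged on the GPU
    # even if the parameters and optimizer state currently live on CPU.
    # Gloo (or MPI) can broadcast CPU tensors directly, so keep the default device.
    return torch.device("cuda") if backend == dist.Backend.NCCL else default_device
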
@@ -326,7 +326,10 @@ class OSS(Optimizer):
         self._all_states = []
         should_collect_state = self.rank == recipient_rank or recipient_rank == -1
-        should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1
+        should_send_state = self.rank != recipient_rank
+
+        # NCCL requires CUDA tensors for all communication primitives
+        dist_device = torch.device("cuda") if self.backend == dist.Backend.NCCL else self._default_device

         for rank in range(self.world_size):
             if rank == self.rank:
@@ -340,18 +343,18 @@ class OSS(Optimizer):
                 state_to_share = (
                     self.optim.state_dict()
                     if should_send_state
-                    else torch.tensor([0], dtype=torch.uint8, device=self._default_device)
+                    else torch.tensor([0], dtype=torch.uint8, device=dist_device)
                 )
                 broadcast_object(
-                    state_to_share, src_rank=self.global_rank, group=self.group, dist_device=self._default_device,
+                    state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device,
                 )
             else:
                 # Fetch the optim state from the other replicas
                 replica_state = broadcast_object(
-                    torch.tensor([0], dtype=torch.uint8, device=self._default_device),
+                    torch.tensor([0], dtype=torch.uint8, device=dist_device),
                     src_rank=self._local_to_global_rank[rank],
                     group=self.group,
-                    dist_device=self._default_device,
+                    dist_device=dist_device,
                 )

                 if should_collect_state:
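
For context, broadcast_object (a fairscale helper whose body is not shown in this diff) serializes an arbitrary Python object and broadcasts it as a byte tensor; the dist_device argument decides where that buffer is staged. The following is a rough, hypothetical sketch of that pattern, assuming a pickle-then-broadcast approach rather than reproducing fairscale's actual implementation, to show why the staging device must be CUDA under NCCL.

import pickle
import torch
import torch.distributed as dist

def broadcast_object_sketch(obj, src_rank, group, dist_device):
    # Hypothetical illustration only: pickle the object, stage the bytes on
    # `dist_device`, then broadcast the length followed by the payload.
    # With the NCCL backend these broadcasts fail unless `dist_device` is CUDA.
    if dist.get_rank() == src_rank:
        payload = torch.tensor(list(pickle.dumps(obj)), dtype=torch.uint8, device=dist_device)
        length = torch.tensor([payload.numel()], dtype=torch.long, device=dist_device)
        dist.broadcast(length, src=src_rank, group=group)
        dist.broadcast(payload, src=src_rank, group=group)
        return obj
    else:
        length = torch.tensor([0], dtype=torch.long, device=dist_device)
        dist.broadcast(length, src=src_rank, group=group)
        payload = torch.empty(int(length.item()), dtype=torch.uint8, device=dist_device)
        dist.broadcast(payload, src=src_rank, group=group)
        return pickle.loads(bytes(payload.cpu().tolist()))
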
@@ -470,6 +470,11 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
     _ = optimizer.step(closure=closure)
     check_same_models_across_ranks(model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=False)

+    # Check that if the model is moved to cpu, the optimizer consolidation still works
+    model.cpu()
+    optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99)
+    optimizer.consolidate_state_dict(recipient_rank=reference_rank)
+
     dist.destroy_process_group()
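
As a usage note, the added test exercises exactly the case the fix targets: a model moved to CPU while the process group runs the NCCL backend. A minimal sketch of that call pattern, assuming a process group has already been initialized elsewhere:

import torch
import torch.distributed as dist
from fairscale.optim import OSS

def consolidate_after_cpu_move(model: torch.nn.Module, reference_rank: int = 0) -> None:
    # Assumes dist.init_process_group(...) has already run (e.g. with the NCCL backend).
    model.cpu()  # parameters, and therefore new optimizer state, now live on CPU
    optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
    # With this fix, the broadcasts inside consolidate_state_dict are staged on CUDA
    # when the backend is NCCL, so consolidation works even for a CPU model.
    optimizer.consolidate_state_dict(recipient_rank=reference_rank)
    if dist.get_rank() == reference_rank:
        _ = optimizer.state_dict()  # consolidated state is available on this rank
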