Unverified Commit eab1551a authored by Benjamin Lefaudeux, committed by GitHub

[OSS] Fix for torch dist broadcast randomly failing on dummy object (#323)

* fix for torch dist broadcast failing on dummy object
parent 1ece280a
@@ -327,6 +327,9 @@ class OSS(Optimizer):
             self.local_state_dict(), non_blocking=True, device=torch.device("cpu")
         )
+        # Tensor cannot be really empty, even if its size is meaningless
+        dummy_sync_tensor = torch.tensor([1], device=self._device)
         for rank in range(self.world_size):
             if rank == self.rank:
                 # Send the state to the reference replica
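The hunk above pre-allocates one real tensor on `self._device` and reuses it as the placeholder payload, instead of building a bare Python `[0]` at each call site. Below is a minimal sketch of the resulting pattern, assuming an already-initialized process group; `consolidate_sketch` and its arguments are illustrative stand-ins for `OSS.consolidate_state_dict`, not fairscale's actual code:

```python
import torch
import torch.distributed as dist

def consolidate_sketch(rank: int, world_size: int, recipient: int, state: dict, device: torch.device):
    # Allocated once, on the device backing the process group; the tensor
    # cannot be truly empty, but its contents are meaningless.
    dummy_sync_tensor = torch.tensor([1], device=device)
    gathered = []

    # NCCL does not support gather, so each rank's state is broadcast in
    # turn, and every rank joins every round to keep the collective matched.
    for src in range(world_size):
        payload = [state] if src == rank else [dummy_sync_tensor]
        dist.broadcast_object_list(payload, src=src)
        if rank == recipient:
            # Only the recipient keeps what was received; all other ranks
            # discard the overwritten placeholder.
            gathered.append(payload[0])
    return gathered if rank == recipient else None
```

Passing a device-resident tensor keeps every rank's call consistent round after round, which the ad-hoc bare `[0]` apparently could not guarantee, hence the intermittent failures the commit title describes.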
@@ -346,10 +349,10 @@
                 # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
                 if _torch_broadcast_object:
-                    dist.broadcast_object_list([0], src=global_rank, group=self.group)
+                    dist.broadcast_object_list([dummy_sync_tensor], src=global_rank, group=self.group)
                 else:
                     broadcast_object(
-                        torch.tensor([0], dtype=torch.uint8, device=self._device),
+                        torch.tensor([dummy_sync_tensor], dtype=torch.uint8, device=self._device),
                         src_rank=global_rank,
                         group=self.group,
                         dist_device=self._device,
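The `else` branch above is the fallback for torch versions without `dist.broadcast_object_list`; fairscale's `broadcast_object` helper ships the payload as a `uint8` tensor, as the diff shows. A hedged sketch of that general serialize-then-broadcast technique follows; it is not the actual helper, and `broadcast_object_sketch` is a made-up name:

```python
import io
import pickle

import torch
import torch.distributed as dist

def broadcast_object_sketch(obj, src_rank: int, group, dist_device: torch.device):
    # Pickle the payload into a uint8 tensor, broadcast its length first so
    # receivers can size their buffer, then broadcast the bytes themselves.
    if dist.get_rank(group=group) == src_rank:
        buf = io.BytesIO()
        pickle.dump(obj, buf)
        payload = torch.tensor(bytearray(buf.getvalue()), dtype=torch.uint8, device=dist_device)
        length = torch.tensor([payload.numel()], dtype=torch.long, device=dist_device)
        dist.broadcast(length, src=src_rank, group=group)
        dist.broadcast(payload, src=src_rank, group=group)
        return obj
    length = torch.zeros(1, dtype=torch.long, device=dist_device)
    dist.broadcast(length, src=src_rank, group=group)
    payload = torch.empty(int(length.item()), dtype=torch.uint8, device=dist_device)
    dist.broadcast(payload, src=src_rank, group=group)
    return pickle.loads(payload.cpu().numpy().tobytes())
```

Because raw `dist.broadcast` only moves tensors, `dist_device` must match the backend (a CUDA device for NCCL), which is why the diff threads `self._device` through both branches.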
@@ -28,4 +28,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["datasets", "golden_configs", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]