Unverified commit 9c4e6d1a authored by Benjamin Lefaudeux, committed by GitHub

[fix] flaky SDP tests with Gloo, checking all handles (#499)

* seemingly fixes the flakiness with Gloo by checking all the comm handles
parent 8eaa3622
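For context, the commit message describes moving from "wait only on the last async broadcast" to "wait on every handle unless the backend is NCCL". Below is a minimal standalone sketch of that pattern; `sync_tensors` is a hypothetical helper for illustration, not fairscale code, and it assumes `torch.distributed` has already been initialized on each rank.

```python
from typing import List

import torch
import torch.distributed as dist


def sync_tensors(tensors: List[torch.Tensor], src_global_rank: int, group=None, blocking: bool = True) -> None:
    """Broadcast each tensor from src_global_rank, keeping every work handle."""
    # Launch all broadcasts asynchronously and collect the returned handles.
    work_handles = [
        dist.broadcast(t, src=src_global_rank, group=group, async_op=True) for t in tensors
    ]

    if not (blocking and work_handles):
        return

    if dist.get_backend(group) == dist.Backend.NCCL:
        # NCCL broadcasts are enqueued on the same CUDA stream, so waiting on the
        # last handle is enough (the reasoning in the pre-fix comments).
        work_handles[-1].wait()
    else:
        # Gloo gives no such ordering guarantee: wait on every handle so that no
        # broadcast is still in flight when this call returns.
        for handle in work_handles:
            handle.wait()
```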
@@ -120,6 +120,7 @@ class ShardedDataParallel(nn.Module):
         # Communication related attributes
         self.process_group = process_group if process_group is not None else dist.group.WORLD
+        self.backend = dist.get_backend(self.process_group)
         self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
         self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
         self.rank = dist.get_rank(self.process_group)
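The `self.backend` attribute added above is what the hunks below compare against `dist.Backend.NCCL`. A small sketch of that check, under the assumption that a Gloo process group has already been initialized on this rank:

```python
import torch.distributed as dist

# Assumes dist.init_process_group(backend="gloo", ...) already ran on this rank.
backend = dist.get_backend()  # e.g. "gloo" or "nccl"; dist.Backend values subclass str
if backend != dist.Backend.NCCL:
    print("non-NCCL backend: wait on every outstanding work handle")
else:
    print("NCCL backend: waiting on the last handle is considered sufficient")
```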
@@ -311,16 +312,18 @@ class ShardedDataParallel(nn.Module):
             blocking (bool): wait for the operation to conclude.
         """
-        last_work_handle = None
+        work_handles = []
 
         for buffer in self.module.buffers(recurse=True):
-            last_work_handle = dist.broadcast(
-                buffer.data, self.reference_global_rank, self.process_group, async_op=True
+            work_handles.append(
+                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
             )
-        if blocking and last_work_handle:
-            # Only wait for the last coms, they're inlined on the same CUDA stream
-            last_work_handle.wait()
+
+        if blocking and work_handles:
+            if self.backend != dist.Backend.NCCL:
+                _ = list(filter(lambda x: x.wait(), work_handles))
+            else:
+                work_handles[-1].wait()
 
     def zero_grad(self, set_to_none: bool = False) -> None:
         r"""Sets gradients of all model parameters to zero. See similar function
@@ -505,16 +508,18 @@ class ShardedDataParallel(nn.Module):
         Sync the complete model states in between the ranks
         """
-        last_work_handle = None
+        work_handles = []
 
         for t in self.module.state_dict().values():
-            last_work_handle = dist.broadcast(
-                t, src=self.reference_global_rank, group=self.process_group, async_op=True
+            work_handles.append(
+                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
             )
 
-        # Only wait for the last handle, they're inlined in the same CUDA stream
-        if last_work_handle:
-            last_work_handle.wait()
+        # gloo does not guarantee inlining like NCCL, wait for all requests
+        if self.backend != dist.Backend.NCCL:
+            _ = list(filter(lambda x: x.wait(), work_handles))
+        elif work_handles:
+            work_handles[-1].wait()
 
     def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
@@ -546,19 +546,19 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
        """Helper function to broadcast all the parameters from a given device"""
-        last_work_handle = None  # Work handles are consumed within this scope, no callback
+        work_handles = []  # Work handles are consumed within this scope, no callback
 
         for device in self.buckets.keys():
             for src_rank, bucket in enumerate(self.buckets[device]):
                 if bucket.numel() > 0:
                     global_src_rank = self.get_global_rank(self.group, src_rank)
-                    last_work_handle = dist.broadcast(
-                        tensor=bucket, src=self._local_to_global_rank[src_rank], group=self.group, async_op=True
+                    work_handles.append(
+                        dist.broadcast(
+                            tensor=bucket, src=self._local_to_global_rank[src_rank], group=self.group, async_op=True
+                        )
                     )
 
-        # Only check on the last handle, they're all inlined on the same CUDA stream
-        if last_work_handle:
-            last_work_handle.wait()
+        _ = list(filter(lambda x: x.wait(), work_handles))
 
     def _setup_flat_buffers(self) -> None:
         """Make all params which are on the same device and tied to the same rank views of a single buffer.
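A side note on the one-liner used to drain the handles: `list(filter(lambda x: x.wait(), work_handles))` exists only to force `wait()` on each element; whatever `wait()` returns is discarded along with the list. An equivalent, more explicit spelling is a plain loop, shown here purely as a sketch of what the idiom does rather than a proposed change:

```python
# Equivalent to `_ = list(filter(lambda x: x.wait(), work_handles))`:
# block on every outstanding broadcast, ignoring the return values.
for work_handle in work_handles:
    work_handle.wait()
```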