Unverified Commit 9c4e6d1a authored by Benjamin Lefaudeux, committed by GitHub

[fix] flaky SDP tests with Gloo, checking all handles (#499)

* seemingly fix flakiness for Gloo by checking all comms handles
parent 8eaa3622
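
The gist of the fix: with NCCL, back-to-back async broadcasts are inlined on the same CUDA stream, so waiting on the last work handle implies all earlier transfers have completed. Gloo gives no such ordering guarantee, so every handle must be waited on. A minimal sketch of the pattern under those assumptions (the `broadcast_all` helper is illustrative, not part of the patch):

```python
import torch.distributed as dist


def broadcast_all(tensors, src_rank, group, backend):
    """Illustrative helper: broadcast a list of tensors asynchronously,
    then wait according to the backend's ordering guarantees."""
    handles = [dist.broadcast(t, src_rank, group=group, async_op=True) for t in tensors]
    if not handles:
        return

    if backend == dist.Backend.NCCL:
        # NCCL inlines the transfers on one CUDA stream: the last handle
        # completing implies all previous ones completed as well.
        handles[-1].wait()
    else:
        # Gloo (and other backends) make no such guarantee: wait on every handle.
        for handle in handles:
            handle.wait()
```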
```diff
@@ -120,6 +120,7 @@ class ShardedDataParallel(nn.Module):
         # Communication related attributes
         self.process_group = process_group if process_group is not None else dist.group.WORLD
+        self.backend = dist.get_backend(self.process_group)
         self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
         self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
         self.rank = dist.get_rank(self.process_group)
```
```diff
@@ -311,16 +312,18 @@ class ShardedDataParallel(nn.Module):
             blocking (bool): wait for the operation to conclude.
         """
-        last_work_handle = None
+        work_handles = []
         for buffer in self.module.buffers(recurse=True):
-            last_work_handle = dist.broadcast(
-                buffer.data, self.reference_global_rank, self.process_group, async_op=True
-            )
+            work_handles.append(
+                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
+            )
 
-        if blocking and last_work_handle:
-            # Only wait for the last coms, they're inlined on the same CUDA stream
-            last_work_handle.wait()
+        if blocking and work_handles:
+            if self.backend != dist.Backend.NCCL:
+                _ = list(filter(lambda x: x.wait(), work_handles))
+            else:
+                work_handles[-1].wait()
 
     def zero_grad(self, set_to_none: bool = False) -> None:
         r"""Sets gradients of all model parameters to zero. See similar function
```
```diff
@@ -505,16 +508,18 @@ class ShardedDataParallel(nn.Module):
         Sync the complete model states in between the ranks
         """
-        last_work_handle = None
+        work_handles = []
         for t in self.module.state_dict().values():
-            last_work_handle = dist.broadcast(
-                t, src=self.reference_global_rank, group=self.process_group, async_op=True
-            )
+            work_handles.append(
+                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
+            )
 
-        # Only wait for the last handle, they're inlined in the same CUDA stream
-        if last_work_handle:
-            last_work_handle.wait()
+        # gloo does not guarantee inlining like NCCL, wait for all requests
+        if self.backend != dist.Backend.NCCL:
+            _ = list(filter(lambda x: x.wait(), work_handles))
+        elif work_handles:
+            work_handles[-1].wait()
 
     def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
```
```diff
@@ -546,19 +546,19 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
-        last_work_handle = None  # Work handles are consumed within this scope, no callback
+        work_handles = []  # Work handles are consumed within this scope, no callback
 
         for device in self.buckets.keys():
             for src_rank, bucket in enumerate(self.buckets[device]):
                 if bucket.numel() > 0:
-                    global_src_rank = self.get_global_rank(self.group, src_rank)
-                    last_work_handle = dist.broadcast(
-                        tensor=bucket, src=self._local_to_global_rank[src_rank], group=self.group, async_op=True
-                    )
+                    work_handles.append(
+                        dist.broadcast(
+                            tensor=bucket, src=self._local_to_global_rank[src_rank], group=self.group, async_op=True
+                        )
+                    )
 
-        # Only check on the last handle, they're all inlined on the same CUDA stream
-        if last_work_handle:
-            last_work_handle.wait()
+        # Wait on all the handles; unlike NCCL, Gloo does not inline them on a single stream
+        _ = list(filter(lambda x: x.wait(), work_handles))
 
     def _setup_flat_buffers(self) -> None:
         """Make all params which are on the same device and tied to the same rank views of a single buffer.
```
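
For context, a standalone sketch of the failure mode the commit message describes, assuming a two-process CPU setup with the Gloo backend (the port and tensor sizes are arbitrary; this repro is illustrative, not taken from the test suite):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"  # assumed-free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Rank 0 broadcasts several tensors asynchronously.
    tensors = [torch.full((4,), float(rank)) for _ in range(8)]
    handles = [dist.broadcast(t, src=0, async_op=True) for t in tensors]

    # With Gloo, only tensors whose handle has been waited on are guaranteed
    # to hold the broadcast result; waiting on the last handle alone is not enough.
    for handle in handles:
        handle.wait()

    assert all(torch.equal(t, torch.zeros(4)) for t in tensors)
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```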