Unverified commit c391e4b6, authored by youkaichao, committed by GitHub
Browse files

[Core] improve robustness of pynccl (#3860)

parent 9117f892
...@@ -236,22 +236,25 @@ class NCCLCommunicator: ...@@ -236,22 +236,25 @@ class NCCLCommunicator:
if local_rank == -1: if local_rank == -1:
local_rank = self.rank local_rank = self.rank
self.local_rank = local_rank self.local_rank = local_rank
torch.cuda.set_device(local_rank) # don't use these args, as they can be -1
if rank == 0: # use `self.rank`, `self.local_rank` and `self.world_size` instead
del world_size, rank, local_rank
torch.cuda.set_device(self.local_rank)
if self.rank == 0:
self.unique_id = ncclGetUniqueId() self.unique_id = ncclGetUniqueId()
else: else:
self.unique_id = NcclUniqueId() self.unique_id = NcclUniqueId()
tensor = torch.ByteTensor(list( tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
self.unique_id.internal)).cuda(local_rank) self.local_rank)
dist.broadcast(tensor, src=0) dist.broadcast(tensor, src=0)
byte_list = tensor.cpu().tolist() byte_list = tensor.cpu().tolist()
for i, byte in enumerate(byte_list): for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p() self.comm = ctypes.c_void_p()
result = _c_ncclCommInitRank(ctypes.byref(self.comm), world_size, result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, rank) self.unique_id, self.rank)
assert result == 0 assert result == 0
self.stream = torch.cuda.Stream(device=f"cuda:{local_rank}") self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
def all_reduce(self, def all_reduce(self,
tensor: torch.Tensor, tensor: torch.Tensor,
...@@ -271,4 +274,6 @@ class NCCLCommunicator: ...@@ -271,4 +274,6 @@ class NCCLCommunicator:
# `dist` module might have been already destroyed # `dist` module might have been already destroyed
if hasattr(dist, 'destroy_process_group'): if hasattr(dist, 'destroy_process_group'):
dist.destroy_process_group() dist.destroy_process_group()
_c_ncclCommDestroy(self.comm) # function might have been already destroyed
if _c_ncclCommDestroy is not None:
_c_ncclCommDestroy(self.comm)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment