Unverified Commit 3cd9b5bb authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Core][Distributed] use existing torch.cuda.device (#4318)

[Core][Distributed] use existing torch.cuda.device context manager (#4318)
parent 468d761b
......@@ -250,15 +250,13 @@ class NCCLCommunicator:
assert isinstance(device, torch.device)
self.device = device
# nccl communicator and stream will use this device
current_device = torch.cuda.current_device()
try:
torch.cuda.set_device(device)
# `torch.cuda.device` is a context manager that changes the
# current cuda device to the specified one
with torch.cuda.device(device):
NCCL_CHECK(
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank))
self.stream = torch.cuda.Stream()
finally:
torch.cuda.set_device(current_device)
def all_reduce(self,
tensor: torch.Tensor,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment