"vscode:/vscode.git/clone" did not exist on "27edd2aeb410e3126655ca76ef3fd32913ca52fa"
Unverified Commit 3cd9b5bb authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Core][Distributed] use existing torch.cuda.device (#4318)

[Core][Distributed] use existing torch.cuda.device context manager (#4318)
parent 468d761b
...@@ -250,15 +250,13 @@ class NCCLCommunicator: ...@@ -250,15 +250,13 @@ class NCCLCommunicator:
assert isinstance(device, torch.device) assert isinstance(device, torch.device)
self.device = device self.device = device
# nccl communicator and stream will use this device # nccl communicator and stream will use this device
current_device = torch.cuda.current_device() # `torch.cuda.device` is a context manager that changes the
try: # current cuda device to the specified one
torch.cuda.set_device(device) with torch.cuda.device(device):
NCCL_CHECK( NCCL_CHECK(
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size, _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank)) self.unique_id, self.rank))
self.stream = torch.cuda.Stream() self.stream = torch.cuda.Stream()
finally:
torch.cuda.set_device(current_device)
def all_reduce(self, def all_reduce(self,
tensor: torch.Tensor, tensor: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment