"docs/vscode:/vscode.git/clone" did not exist on "38bed912e36e1725ede1f0e8c61a514f378697c3"
Unverified Commit 277d49a5 authored by Antoni Baum, committed by GitHub

Do not initialize `torch.distributed` process group if one is already initialized (#16487)

* Do not initialize torch process group twice

* Apply suggestions from code review
parent 2b483230
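Why the guard is needed: PyTorch raises a RuntimeError when `init_process_group` is called while the default process group already exists, so the diff below wraps each call in a `torch.distributed.is_initialized()` check. The following is a minimal single-process sketch of that pattern; the `gloo` backend, env-style rendezvous variables, and the `ensure_process_group` helper are illustrative assumptions, not part of the commit.

import os

import torch.distributed as dist


def ensure_process_group(backend: str = "gloo") -> None:
    # Single-process illustration: pretend a launcher exported these variables.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    # Same guard the patch adds: only create the default group if it is missing.
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)


ensure_process_group()
ensure_process_group()  # second call is now a no-op instead of raising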
@@ -77,6 +77,11 @@ class SageMakerTrainingArguments(TrainingArguments):
     @cached_property
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
+        if torch.distributed.is_initialized() and self.local_rank == -1:
+            logger.warning(
+                "torch.distributed process group is initialized, but local_rank == -1. "
+                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+            )
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
@@ -105,7 +110,8 @@ class SageMakerTrainingArguments(TrainingArguments):
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
-            torch.distributed.init_process_group(backend="nccl")
+            if not torch.distributed.is_initialized():
+                torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
@@ -1022,10 +1022,15 @@ class TrainingArguments:
     @torch_required
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
+        if torch.distributed.is_initialized() and self.local_rank == -1:
+            logger.warning(
+                "torch.distributed process group is initialized, but local_rank == -1. "
+                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+            )
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
-            if self.local_rank != -1:
+            if self.local_rank != -1 and not torch.distributed.is_initialized():
                 # Initializes distributed backend for cpu
                 if self.xpu_backend not in ("mpi", "ccl"):
                     raise ValueError(
@@ -1076,7 +1081,8 @@ class TrainingArguments:
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
-            torch.distributed.init_process_group(backend="nccl")
+            if not torch.distributed.is_initialized():
+                torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
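For context, the scenario this unblocks is a script (or an orchestration layer) that sets up `torch.distributed` itself before constructing `TrainingArguments`; previously `_setup_devices` would call `init_process_group` a second time and fail. The sketch below is an illustrative assumption, not part of the commit: single process, CPU-only, `gloo` backend, exercising the warning path (`local_rank == -1`) added above.

import os

import torch.distributed as dist
from transformers import TrainingArguments

# Pretend an external launcher already created the default process group.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
if not dist.is_initialized():
    dist.init_process_group(backend="gloo")

# local_rank defaults to -1, so accessing .device should log the new warning
# ("torch.distributed process group is initialized, but local_rank == -1. ...")
# instead of crashing on a duplicate init_process_group call.
args = TrainingArguments(output_dir="out", no_cuda=True)
print(args.device)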