Unverified Commit 277d49a5 authored by Antoni Baum, committed by GitHub

Do not initialize `torch.distributed` process group if one is already initialized (#16487)

* Do not initialize torch process group twice

* Apply suggestions from code review
parent 2b483230
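Both files receive the same guard pattern: warn when a process group already exists but `local_rank` is still -1, and only call `torch.distributed.init_process_group` if no group has been initialized yet. A minimal standalone sketch of that pattern, lifted out of the `TrainingArguments` context (the `local_rank` parameter here is a stand-in for the training argument, and `init_process_group(backend="nccl")` still expects the usual launcher environment variables such as RANK, WORLD_SIZE, and MASTER_ADDR):

```python
import logging

import torch.distributed as dist

logger = logging.getLogger(__name__)


def setup_process_group(local_rank: int) -> None:
    # Warn when a process group already exists but the script was not launched
    # with a distributed launcher (local_rank stays at its default of -1).
    if dist.is_initialized() and local_rank == -1:
        logger.warning(
            "torch.distributed process group is initialized, but local_rank == -1. "
            "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch`"
        )
    # Only initialize the NCCL backend if nothing has set up the group already,
    # e.g. an external launcher or framework that created it beforehand.
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
```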
@@ -77,6 +77,11 @@ class SageMakerTrainingArguments(TrainingArguments):
     @cached_property
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
+        if torch.distributed.is_initialized() and self.local_rank == -1:
+            logger.warning(
+                "torch.distributed process group is initialized, but local_rank == -1. "
+                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+            )
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
@@ -105,7 +110,8 @@ class SageMakerTrainingArguments(TrainingArguments):
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
-            torch.distributed.init_process_group(backend="nccl")
+            if not torch.distributed.is_initialized():
+                torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
...
@@ -1022,10 +1022,15 @@ class TrainingArguments:
     @torch_required
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
+        if torch.distributed.is_initialized() and self.local_rank == -1:
+            logger.warning(
+                "torch.distributed process group is initialized, but local_rank == -1. "
+                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+            )
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
-            if self.local_rank != -1:
+            if self.local_rank != -1 and not torch.distributed.is_initialized():
                 # Initializes distributed backend for cpu
                 if self.xpu_backend not in ("mpi", "ccl"):
                     raise ValueError(
@@ -1076,7 +1081,8 @@ class TrainingArguments:
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
-            torch.distributed.init_process_group(backend="nccl")
+            if not torch.distributed.is_initialized():
+                torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
...
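With this change, code that sets up the process group itself before building the training arguments no longer triggers a second `init_process_group` call, which would otherwise raise a RuntimeError. A rough usage sketch under that assumption (the argument values are placeholders, not taken from the commit, and the standard distributed environment variables are assumed to be set by the launcher):

```python
import os

import torch.distributed as dist
from transformers import TrainingArguments

# Hypothetical: the surrounding launcher or framework has already created
# (or now creates) the process group before the Trainer is involved.
if not dist.is_initialized():
    dist.init_process_group(backend="nccl")

# When the device is set up, the existing group is reused instead of being
# initialized a second time.
args = TrainingArguments(
    output_dir="./out",
    local_rank=int(os.environ.get("LOCAL_RANK", -1)),
)
```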