"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3fcfbe7549d9694f96e1f19630add4adf99dd421"
Unverified commit cf0af9a3 authored by heya5, committed by GitHub

[Trainer] Add optional communication backends for torch.distributed when using GPU (#22247)

Update training_args.py
parent c4bf6f38
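
For context, the change concerns how torch.distributed is brought up before training. Below is a minimal sketch of that mechanism, assuming the process is started under a launcher such as torchrun (which sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT); the COMM_BACKEND environment variable is hypothetical, used here only to illustrate switching among the backends this patch deals with:

# Sketch: initializing torch.distributed with a selectable backend.
# "nccl" is the usual choice for GPU training; "gloo" and "mpi" are
# alternatives (e.g. when NCCL is unavailable, or when PyTorch was
# built with MPI support). COMM_BACKEND is hypothetical, not part of
# transformers or PyTorch.
import os

import torch.distributed as dist

backend = os.environ.get("COMM_BACKEND", "gloo")  # "nccl", "gloo", or "mpi"
dist.init_process_group(backend=backend)  # rank/world size read from the env
print(f"rank {dist.get_rank()}/{dist.get_world_size()} via {backend}")
dist.destroy_process_group()

Run with, e.g., torchrun --nproc_per_node=2 script.py.
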
@@ -1641,7 +1641,10 @@ class TrainingArguments:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             if not torch.distributed.is_initialized():
-                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
+                if self.xpu_backend and self.xpu_backend in ("mpi", "gloo"):
+                    torch.distributed.init_process_group(backend=self.xpu_backend, timeout=self.ddp_timeout_delta)
+                else:
+                    torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
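
A hedged usage sketch of what the new branch enables, assuming a transformers build that includes this patch; the output directory and batch size are illustrative only:

# With the patch, a user can opt into "mpi" or "gloo" for GPU runs;
# anything else (including the unset default) falls through to "nccl".
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./output",            # illustrative path
    per_device_train_batch_size=8,    # illustrative hyperparameter
    xpu_backend="gloo",               # consumed by the branch added above
)
# Launched under torchrun, the Trainer should now call
# torch.distributed.init_process_group(backend="gloo", ...) on GPU.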