Unverified Commit f7c93b3c authored by Bram Vanroy, committed by GitHub

Possible fix to make AMP work with DDP in the trainer (#4728)

* Manually set device in trainer args

* Check if current device is cuda before set_device

* Explicitly set GPU ID when using single GPU

This addresses https://github.com/huggingface/transformers/issues/4657#issuecomment-642228099
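
For context, here is a minimal sketch of the per-process setup this change enforces (illustrative only; the `setup_process` helper and the toy `Linear` model are not part of the commit). Each DDP worker needs the *current* CUDA device pinned to its own GPU before the model is built, because anything that implicitly allocates on the current device (as apex AMP does with its state tensors) would otherwise land on GPU 0 in every process:

```python
import torch

def setup_process(local_rank: int):
    # Hypothetical helper, not Trainer code. Assumes the launcher (e.g.
    # torch.distributed.launch) has set MASTER_ADDR, RANK, WORLD_SIZE, etc.
    torch.distributed.init_process_group(backend="nccl")
    device = torch.device("cuda", local_rank)
    # The fix in this commit: make the process's current device match `device`.
    # Without it, implicit allocations default to cuda:0 in every process.
    torch.cuda.set_device(device)
    model = torch.nn.Linear(10, 10).to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    return model, device
```

The diff below does this once, when `TrainingArguments` resolves its device, so the rest of the Trainer inherits the correct current device.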
parent 66bcfbb1
```diff
@@ -172,7 +172,11 @@ class TrainingArguments:
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
+            # trigger an error that a device index is missing. Index 0 takes into account the
+            # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
+            # will use the first GPU in that env, i.e. GPU#1
+            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             n_gpu = torch.cuda.device_count()
         else:
             # Here, we'll use torch.distributed.
@@ -180,6 +184,10 @@ class TrainingArguments:
             torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
             n_gpu = 1
+
+        if device.type == "cuda":
+            torch.cuda.set_device(device)
+
         return device, n_gpu
 
     @property
```
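
The `CUDA_VISIBLE_DEVICES` remapping described in the new comment is easy to verify with a standalone check (a hypothetical snippet, not part of the diff): with `CUDA_VISIBLE_DEVICES=1,2`, PyTorch renumbers the visible GPUs from zero, so `cuda:0` refers to physical GPU #1:

```python
# Run as: CUDA_VISIBLE_DEVICES=1,2 python check_visible.py
# (`check_visible.py` is an illustrative file name, not a file in the repo.)
import torch

print(torch.cuda.device_count())      # 2 -- only the visible GPUs are counted
print(torch.cuda.get_device_name(0))  # physical GPU #1, re-exposed as cuda:0
torch.cuda.set_device(torch.device("cuda:0"))  # fine: the index is explicit
# torch.cuda.set_device(torch.device("cuda")) would raise, since no device
# index is given -- the error the new comment refers to.
```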