Unverified Commit 2aa2a144 authored by Kazuaki Ishizaki's avatar Kazuaki Ishizaki Committed by GitHub
Browse files

Make tensor device correct when ACCELERATE_TORCH_DEVICE is defined (#31751)

return correct device when ACCELERATE_TORCH_DEVICE is defined
parent 8c5c180d
......@@ -2194,7 +2194,9 @@ class TrainingArguments:
# trigger an error that a device index is missing. Index 0 takes into account the
# GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
# will use the first GPU in that env, i.e. GPU#1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device(
"cuda:0" if torch.cuda.is_available() else os.environ.get("ACCELERATE_TORCH_DEVICE", "cpu")
)
# Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
# the default value.
self._n_gpu = torch.cuda.device_count()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment