"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "a54728f8c1087a7a8c604732de21d12193b3c23a"
Unverified commit bcc069dd, authored by jeffhataws and committed by GitHub

Enable bf16 option for XLA devices (#20684)

parent 9858ecd7
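The three hunks below relax device checks so that `--bf16` is accepted when training runs on an XLA device (a TPU, or a Trainium/NeuronCore reached through torch-xla). As a hedged illustration only (not part of this commit), enabling the option from Python might look like the sketch below; `output_dir`, `bf16`, and `per_device_train_batch_size` are real `TrainingArguments` fields, while the output path and batch size are made up.

```python
# Minimal sketch, assuming the script runs under a torch_xla runtime (e.g. a TPU VM);
# on a machine without bf16 support the __post_init__ checks shown below still raise.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                  # hypothetical output path
    bf16=True,                         # now accepted on XLA (TPU/NeuronCore) devices
    per_device_train_batch_size=8,     # hypothetical batch size
)
```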
@@ -565,7 +565,7 @@ class Trainer:
             logger.info(f"Using {args.half_precision_backend} half precision backend")

         self.do_grad_scaling = False
-        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
+        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
             if args.half_precision_backend == "cuda_amp":
                 self.use_cuda_amp = True
@@ -1122,9 +1122,9 @@ class TrainingArguments:
         if self.bf16 or self.bf16_full_eval:
-            if self.no_cuda and not is_torch_bf16_cpu_available():
+            if self.no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
                 # cpu
-                raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10")
+                raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
             elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
                 # gpu
                 raise ValueError(
@@ -1172,12 +1172,13 @@ class TrainingArguments:
             and is_torch_available()
             and (self.device.type != "cuda")
             and (get_xla_device_type(self.device) != "GPU")
+            and (get_xla_device_type(self.device) != "TPU")
             and (self.device.type != "cpu")
             and (self.bf16 or self.bf16_full_eval)
         ):
             raise ValueError(
                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
+                " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices."
             )
         if self.torchdynamo is not None:
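A second, purely illustrative sketch: instead of hard-coding the flag, a training script could gate it on the runtime using `is_torch_tpu_available`, the same helper the diff calls; the GPU capability check and everything else here are assumptions, not code from this commit.

```python
import torch
from transformers import TrainingArguments
from transformers.utils import is_torch_tpu_available

# Illustrative only: request bf16 when an XLA TPU/NeuronCore runtime is detected,
# or when the local GPU reports bf16 support, mirroring the relaxed checks above.
want_bf16 = is_torch_tpu_available() or (
    torch.cuda.is_available() and torch.cuda.is_bf16_supported()
)
training_args = TrainingArguments(output_dir="out", bf16=want_bf16)
```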