Unverified Commit d35f7296 authored by Yanming W, committed by GitHub

Restore fp16 support on xla gpu device (#22300)

parent 67c2dbdb
...
@@ -598,7 +598,7 @@ class Trainer:
             logger.info(f"Using {args.half_precision_backend} half precision backend")
         self.do_grad_scaling = False
-        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
+        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
            # deepspeed and SageMaker Model Parallel manage their own half precision
            if args.half_precision_backend == "cuda_amp":
                self.use_cuda_amp = True
...
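For context, here is a minimal sketch (not part of this commit) of the user-facing configuration that reaches the code path re-enabled by this change. The removal of the is_torch_tpu_available() check means a model placed on an XLA GPU device can once more take the cuda_amp half-precision branch shown in the diff. The output_dir value is a placeholder; only fp16 and half_precision_backend are relevant here.

from transformers import TrainingArguments

# Hypothetical setup: with is_torch_tpu_available() no longer short-circuiting the
# condition, fp16 via the cuda_amp backend can again be selected on an XLA GPU device.
training_args = TrainingArguments(
    output_dir="output",                # placeholder path
    fp16=True,                          # request half precision
    half_precision_backend="cuda_amp",  # backend handled by the block in the diff above
)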