Unverified Commit 05de038f authored by Abhilash Majumder, committed by GitHub

Flex xpu bug fix (#26135)

flex xpu bug fix
parent 9709ab11
@@ -1425,12 +1425,13 @@ class TrainingArguments:
             and is_torch_available()
             and (self.device.type != "cuda")
             and (self.device.type != "npu")
+            and (self.device.type != "xpu")
             and (get_xla_device_type(self.device) != "GPU")
             and (self.fp16 or self.fp16_full_eval)
         ):
             raise ValueError(
                 "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
-                " (`--fp16_full_eval`) can only be used on CUDA or NPU devices."
+                " (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain XPU devices (with IPEX)."
             )
         if (
...
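Read on its own, the guard rejects `--fp16` / `--fp16_full_eval` unless the device type is cuda, npu, the newly allowed xpu, or an XLA GPU. The sketch below restates that predicate outside of TrainingArguments so the effect of the added xpu branch is easy to exercise in isolation; `SimpleArgs`, `device_type`, `xla_device_type`, and `check_fp16_device` are hypothetical stand-ins invented for illustration, not transformers API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleArgs:
    """Hypothetical stand-in for the TrainingArguments fields used by the guard."""

    device_type: str                # e.g. "cuda", "npu", "xpu", "cpu"
    xla_device_type: Optional[str]  # e.g. "GPU" when running under torch-xla
    fp16: bool = False
    fp16_full_eval: bool = False


def check_fp16_device(args: SimpleArgs) -> None:
    # Mirrors the condition in the diff: raise only when half precision is
    # requested on a device type the fp16 path does not support.
    if (
        args.device_type != "cuda"
        and args.device_type != "npu"
        and args.device_type != "xpu"  # the branch added by this commit
        and args.xla_device_type != "GPU"
        and (args.fp16 or args.fp16_full_eval)
    ):
        raise ValueError(
            "FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision"
            " evaluation (`--fp16_full_eval`) can only be used on CUDA or NPU devices or certain"
            " XPU devices (with IPEX)."
        )


# Before this commit the xpu branch was missing, so an Intel XPU run like
# this one would have raised the ValueError; it now passes silently.
check_fp16_device(SimpleArgs(device_type="xpu", xla_device_type=None, fp16=True))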