"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b3e54698dd701667ba1d06501a7a9e431c020863"
Unverified Commit 63942218 authored by Yanming Wang's avatar Yanming Wang Committed by GitHub
Browse files

Fix XLA fp16 and bf16 error checking (#18913)



* Fix XLA fp16 and bf16 error checking

* Update src/transformers/training_args.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 6690ba3f
...@@ -93,6 +93,15 @@ def get_int_from_env(env_keys, default): ...@@ -93,6 +93,15 @@ def get_int_from_env(env_keys, default):
return default return default
def get_xla_device_type(device: "torch.device") -> Optional[str]:
"""
Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device.
"""
if is_torch_tpu_available():
return xm.xla_real_devices([device])[0].split(":")[0]
return None
class OptimizerNames(ExplicitEnum): class OptimizerNames(ExplicitEnum):
""" """
Stores the acceptable string identifiers for optimizers. Stores the acceptable string identifiers for optimizers.
...@@ -1108,7 +1117,7 @@ class TrainingArguments: ...@@ -1108,7 +1117,7 @@ class TrainingArguments:
self.framework == "pt" self.framework == "pt"
and is_torch_available() and is_torch_available()
and (self.device.type != "cuda") and (self.device.type != "cuda")
and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ) and (get_xla_device_type(self.device) != "GPU")
and (self.fp16 or self.fp16_full_eval) and (self.fp16 or self.fp16_full_eval)
): ):
raise ValueError( raise ValueError(
...@@ -1120,7 +1129,7 @@ class TrainingArguments: ...@@ -1120,7 +1129,7 @@ class TrainingArguments:
self.framework == "pt" self.framework == "pt"
and is_torch_available() and is_torch_available()
and (self.device.type != "cuda") and (self.device.type != "cuda")
and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ) and (get_xla_device_type(self.device) != "GPU")
and (self.device.type != "cpu") and (self.device.type != "cpu")
and (self.bf16 or self.bf16_full_eval) and (self.bf16 or self.bf16_full_eval)
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment