Unverified Commit a649de55 authored by Matt's avatar Matt Committed by GitHub
Browse files

Raise a TF-specific error when importing Torch classes (#18280)



* Raise a TF-specific error when importing Torch classes

* Update src/transformers/utils/import_utils.py
Co-authored-by: default avatarLysandre Debut <lysandre.debut@reseau.eseo.fr>

* Add an inverse error for PyTorch users
Co-authored-by: default avatarLysandre Debut <lysandre.debut@reseau.eseo.fr>
parent 5e0ffd91
......@@ -693,6 +693,30 @@ PYTORCH_IMPORT_ERROR = """
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""
# docstyle-ignore
PYTORCH_IMPORT_ERROR_WITH_TF = """
{0} requires the PyTorch library but it was not found in your environment.
However, we were able to find a TensorFlow installation. TensorFlow classes begin
with "TF", but are otherwise identically named to our PyTorch classes. This
means that the TF equivalent of the class you tried to import would be "TF{0}".
If you want to use TensorFlow, please use TF classes instead!
If you really do want to use PyTorch please go to
https://pytorch.org/get-started/locally/ and follow the instructions that
match your environment.
"""
# docstyle-ignore
TF_IMPORT_ERROR_WITH_PYTORCH = """
{0} requires the TensorFlow library but it was not found in your environment.
However, we were able to find a PyTorch installation. PyTorch classes do not begin
with "TF", but are otherwise identically named to our TF classes.
If you want to use PyTorch, please use those classes instead!
If you really do want to use TensorFlow, please follow the instructions on the
installation page https://www.tensorflow.org/install that match your environment.
"""
# docstyle-ignore
SKLEARN_IMPORT_ERROR = """
......@@ -855,6 +879,15 @@ def requires_backends(obj, backends):
backends = [backends]
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
# Raise an error for users who might not realize that classes without "TF" are torch-only
if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available():
raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
# Raise the inverse error for PyTorch users trying to load TF classes
if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available():
raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
checks = (BACKENDS_MAPPING[backend] for backend in backends)
failed = [msg.format(name) for available, msg in checks if not available()]
if failed:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment