Commit c8b0c1e5 authored by Aymeric Augustin's avatar Aymeric Augustin
Browse files

Improve exception type.

ImportError isn't really appropriate when there's no import involved.
parent 4c09a960
...@@ -107,7 +107,7 @@ class ServeCommand(BaseTransformersCLICommand): ...@@ -107,7 +107,7 @@ class ServeCommand(BaseTransformersCLICommand):
self._host = host self._host = host
self._port = port self._port = port
if not _serve_dependancies_installed: if not _serve_dependancies_installed:
raise ImportError( raise RuntimeError(
"Using serve command requires FastAPI and unicorn. " "Using serve command requires FastAPI and unicorn. "
"Please install transformers with [serving]: pip install transformers[serving]." "Please install transformers with [serving]: pip install transformers[serving]."
"Or install FastAPI and unicorn separatly." "Or install FastAPI and unicorn separatly."
......
...@@ -8,7 +8,7 @@ from transformers.commands import BaseTransformersCLICommand ...@@ -8,7 +8,7 @@ from transformers.commands import BaseTransformersCLICommand
if not is_tf_available() and not is_torch_available(): if not is_tf_available() and not is_torch_available():
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training") raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
# TF training parameters # TF training parameters
USE_XLA = False USE_XLA = False
......
...@@ -324,7 +324,7 @@ def squad_convert_examples_to_features( ...@@ -324,7 +324,7 @@ def squad_convert_examples_to_features(
del new_features del new_features
if return_dataset == "pt": if return_dataset == "pt":
if not is_torch_available(): if not is_torch_available():
raise ImportError("Pytorch must be installed to return a pytorch dataset.") raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
# Convert to Tensors and build dataset # Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
...@@ -354,7 +354,7 @@ def squad_convert_examples_to_features( ...@@ -354,7 +354,7 @@ def squad_convert_examples_to_features(
return features, dataset return features, dataset
elif return_dataset == "tf": elif return_dataset == "tf":
if not is_tf_available(): if not is_tf_available():
raise ImportError("TensorFlow must be installed to return a TensorFlow dataset.") raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
def gen(): def gen():
for ex in features: for ex in features:
......
...@@ -294,7 +294,7 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -294,7 +294,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
return features return features
elif return_tensors == "tf": elif return_tensors == "tf":
if not is_tf_available(): if not is_tf_available():
raise ImportError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported") raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
import tensorflow as tf import tensorflow as tf
def gen(): def gen():
...@@ -309,7 +309,7 @@ class SingleSentenceClassificationProcessor(DataProcessor): ...@@ -309,7 +309,7 @@ class SingleSentenceClassificationProcessor(DataProcessor):
return dataset return dataset
elif return_tensors == "pt": elif return_tensors == "pt":
if not is_torch_available(): if not is_torch_available():
raise ImportError("return_tensors set to 'pt' but PyTorch can't be imported") raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
import torch import torch
from torch.utils.data import TensorDataset from torch.utils.data import TensorDataset
......
...@@ -68,7 +68,7 @@ def get_framework(model=None): ...@@ -68,7 +68,7 @@ def get_framework(model=None):
# Try to guess which framework to use from the model classname # Try to guess which framework to use from the model classname
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt" framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
elif not is_tf_available() and not is_torch_available(): elif not is_tf_available() and not is_torch_available():
raise ImportError( raise RuntimeError(
"At least one of TensorFlow 2.0 or PyTorch should be installed. " "At least one of TensorFlow 2.0 or PyTorch should be installed. "
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
"To install PyTorch, read the instructions at https://pytorch.org/." "To install PyTorch, read the instructions at https://pytorch.org/."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment