Unverified Commit 0911b6bd authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Improving Pipelines by defaulting to framework='tf' when pytorch seems unavailable. (#7728)

* Improving Pipelines by defaulting to framework='tf' when

pytorch seems unavailable.

* Actually changing the default resolution order to account for model
defaults

Adding a new test for each pipeline to check that pipeline(task) works
without manually specifying the framework.
parent 3a134f7c
...@@ -85,31 +85,63 @@ if TYPE_CHECKING: ...@@ -85,31 +85,63 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def get_framework(model=None): def get_framework(model):
""" """
Select framework (TensorFlow or PyTorch) to use. Select framework (TensorFlow or PyTorch) to use.
Args: Args:
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`, `optional`): model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
If both frameworks are installed, picks the one corresponding to the model passed (either a model class or If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
the model name). If no specific model is provided, defaults to using PyTorch. the model name). If no specific model is provided, defaults to using PyTorch.
""" """
if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str): if not is_tf_available() and not is_torch_available():
# Both framework are available but the user supplied a model class instance.
# Try to guess which framework to use from the model classname
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
elif not is_tf_available() and not is_torch_available():
raise RuntimeError( raise RuntimeError(
"At least one of TensorFlow 2.0 or PyTorch should be installed. " "At least one of TensorFlow 2.0 or PyTorch should be installed. "
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ " "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
"To install PyTorch, read the instructions at https://pytorch.org/." "To install PyTorch, read the instructions at https://pytorch.org/."
) )
if isinstance(model, str):
if is_torch_available() and not is_tf_available():
model = AutoModel.from_pretrained(model)
elif is_tf_available() and not is_torch_available():
model = TFAutoModel.from_pretrained(model)
else: else:
# framework = 'tf' if is_tf_available() else 'pt' try:
framework = "pt" if is_torch_available() else "tf" model = AutoModel.from_pretrained(model)
except OSError:
model = TFAutoModel.from_pretrained(model)
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
return framework return framework
def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
    """
    Select a default model to use for a given task. Defaults to PyTorch if ambiguous.

    Args:
        targeted_task (:obj:`Dict`):
            Dictionary representing the given task, that should contain default models
            under ``targeted_task["default"]["model"]`` keyed by framework ("pt"/"tf").
        framework (:obj:`str`, `optional`):
            "pt", "tf" or None, representing a specific framework if it was specified,
            or None if we don't know yet.

    Returns:
        :obj:`str`: The model string representing the default model for this pipeline.
    """
    # If only one framework is installed, it wins regardless of the caller's hint:
    # a default for the missing framework could not be loaded anyway.
    if is_torch_available() and not is_tf_available():
        framework = "pt"
    elif is_tf_available() and not is_torch_available():
        framework = "tf"

    default_models = targeted_task["default"]["model"]
    if framework is None:
        # Both frameworks are available and none was requested: prefer PyTorch.
        framework = "pt"
    return default_models[framework]
class PipelineException(Exception): class PipelineException(Exception):
""" """
Raised by a :class:`~transformers.Pipeline` when handling __call__. Raised by a :class:`~transformers.Pipeline` when handling __call__.
...@@ -2685,14 +2717,16 @@ def pipeline( ...@@ -2685,14 +2717,16 @@ def pipeline(
if task not in SUPPORTED_TASKS: if task not in SUPPORTED_TASKS:
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys()))) raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
framework = framework or get_framework(model)
targeted_task = SUPPORTED_TASKS[task] targeted_task = SUPPORTED_TASKS[task]
task_class, model_class = targeted_task["impl"], targeted_task[framework]
# Use default model/config/tokenizer for the task if no model is provided # Use default model/config/tokenizer for the task if no model is provided
if model is None: if model is None:
model = targeted_task["default"]["model"][framework] # At that point framework might still be undetermined
model = get_default_model(targeted_task, framework)
framework = framework or get_framework(model)
task_class, model_class = targeted_task["impl"], targeted_task[framework]
# Try to infer tokenizer from model or config name (if provided as str) # Try to infer tokenizer from model or config name (if provided as str)
if tokenizer is None: if tokenizer is None:
......
...@@ -10,6 +10,7 @@ DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0 ...@@ -10,6 +10,7 @@ DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
VALID_INPUTS = ["A simple string", ["list of strings"]] VALID_INPUTS = ["A simple string", ["list of strings"]]
NER_FINETUNED_MODELS = ["sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"] NER_FINETUNED_MODELS = ["sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"]
TF_NER_FINETUNED_MODELS = ["Narsil/small"]
# xlnet-base-cased disabled for now, since it crashes TF2 # xlnet-base-cased disabled for now, since it crashes TF2
FEATURE_EXTRACT_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased"] FEATURE_EXTRACT_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased"]
...@@ -804,6 +805,14 @@ class NerPipelineTests(unittest.TestCase): ...@@ -804,6 +805,14 @@ class NerPipelineTests(unittest.TestCase):
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True) nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
self._test_ner_pipeline(nlp, mandatory_keys) self._test_ner_pipeline(nlp, mandatory_keys)
@require_tf
def test_tf_only_ner(self):
    """Check that a TF-only checkpoint is picked up without an explicit framework='tf'."""
    mandatory_keys = {"entity", "word", "score"}
    for model_name in TF_NER_FINETUNED_MODELS:
        # We don't specify framework='tf' but it gets detected automatically
        nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
        self._test_ner_pipeline(nlp, mandatory_keys)
class PipelineCommonTests(unittest.TestCase): class PipelineCommonTests(unittest.TestCase):
pipelines = SUPPORTED_TASKS.keys() pipelines = SUPPORTED_TASKS.keys()
...@@ -815,6 +824,7 @@ class PipelineCommonTests(unittest.TestCase): ...@@ -815,6 +824,7 @@ class PipelineCommonTests(unittest.TestCase):
for task in self.pipelines: for task in self.pipelines:
with self.subTest(msg="Testing TF defaults with TF and {}".format(task)): with self.subTest(msg="Testing TF defaults with TF and {}".format(task)):
pipeline(task, framework="tf") pipeline(task, framework="tf")
pipeline(task)
@require_torch @require_torch
@slow @slow
...@@ -823,3 +833,4 @@ class PipelineCommonTests(unittest.TestCase): ...@@ -823,3 +833,4 @@ class PipelineCommonTests(unittest.TestCase):
for task in self.pipelines: for task in self.pipelines:
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)): with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
pipeline(task, framework="pt") pipeline(task, framework="pt")
pipeline(task)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment