Unverified Commit 0911b6bd authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Improving Pipelines by defaulting to framework='tf' when pytorch seems unavailable. (#7728)

* Improving Pipelines by defaulting to framework='tf' when

pytorch seems unavailable.

* Actually changing the default resolution order to account for model
defaults

Adding a new tests for each pipeline to check that pipeline(task) works
too without manually adding the framework too.
parent 3a134f7c
......@@ -85,31 +85,63 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def get_framework(model=None):
def get_framework(model):
"""
Select framework (TensorFlow or PyTorch) to use.
Args:
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`, `optional`):
model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
the model name). If no specific model is provided, defaults to using PyTorch.
"""
if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str):
# Both framework are available but the user supplied a model class instance.
# Try to guess which framework to use from the model classname
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
elif not is_tf_available() and not is_torch_available():
if not is_tf_available() and not is_torch_available():
raise RuntimeError(
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
"To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
"To install PyTorch, read the instructions at https://pytorch.org/."
)
if isinstance(model, str):
if is_torch_available() and not is_tf_available():
model = AutoModel.from_pretrained(model)
elif is_tf_available() and not is_torch_available():
model = TFAutoModel.from_pretrained(model)
else:
# framework = 'tf' if is_tf_available() else 'pt'
framework = "pt" if is_torch_available() else "tf"
try:
model = AutoModel.from_pretrained(model)
except OSError:
model = TFAutoModel.from_pretrained(model)
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
return framework
def get_default_model(targeted_task: Dict, framework: Optional[str]) -> str:
"""
Select a default model to use for a given task. Defaults to pytorch if ambiguous.
Args:
targeted_task (:obj:`Dict` ):
Dictionnary representing the given task, that should contain default models
framework (:obj:`str`, None)
"pt", "tf" or None, representing a specific framework if it was specified, or None if we don't know yet.
Returns
:obj:`str` The model string representing the default model for this pipeline
"""
if is_torch_available() and not is_tf_available():
framework = "pt"
elif is_tf_available() and not is_torch_available():
framework = "tf"
default_models = targeted_task["default"]["model"]
if framework is None:
framework = "pt"
return default_models[framework]
class PipelineException(Exception):
"""
Raised by a :class:`~transformers.Pipeline` when handling __call__.
......@@ -2685,14 +2717,16 @@ def pipeline(
if task not in SUPPORTED_TASKS:
raise KeyError("Unknown task {}, available tasks are {}".format(task, list(SUPPORTED_TASKS.keys())))
framework = framework or get_framework(model)
targeted_task = SUPPORTED_TASKS[task]
task_class, model_class = targeted_task["impl"], targeted_task[framework]
# Use default model/config/tokenizer for the task if no model is provided
if model is None:
model = targeted_task["default"]["model"][framework]
# At that point framework might still be undetermined
model = get_default_model(targeted_task, framework)
framework = framework or get_framework(model)
task_class, model_class = targeted_task["impl"], targeted_task[framework]
# Try to infer tokenizer from model or config name (if provided as str)
if tokenizer is None:
......
......@@ -10,6 +10,7 @@ DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
VALID_INPUTS = ["A simple string", ["list of strings"]]
NER_FINETUNED_MODELS = ["sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"]
TF_NER_FINETUNED_MODELS = ["Narsil/small"]
# xlnet-base-cased disabled for now, since it crashes TF2
FEATURE_EXTRACT_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased"]
......@@ -804,6 +805,14 @@ class NerPipelineTests(unittest.TestCase):
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name, framework="tf", grouped_entities=True)
self._test_ner_pipeline(nlp, mandatory_keys)
@require_tf
def test_tf_only_ner(self):
mandatory_keys = {"entity", "word", "score"}
for model_name in TF_NER_FINETUNED_MODELS:
# We don't specificy framework='tf' but it gets detected automatically
nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
self._test_ner_pipeline(nlp, mandatory_keys)
class PipelineCommonTests(unittest.TestCase):
pipelines = SUPPORTED_TASKS.keys()
......@@ -815,6 +824,7 @@ class PipelineCommonTests(unittest.TestCase):
for task in self.pipelines:
with self.subTest(msg="Testing TF defaults with TF and {}".format(task)):
pipeline(task, framework="tf")
pipeline(task)
@require_torch
@slow
......@@ -823,3 +833,4 @@ class PipelineCommonTests(unittest.TestCase):
for task in self.pipelines:
with self.subTest(msg="Testing Torch defaults with PyTorch and {}".format(task)):
pipeline(task, framework="pt")
pipeline(task)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment