Unverified Commit 2056f26e authored by Nicolas Patry, committed by GitHub

Extend pipelines for automodel tuples (#12025)



* fix_torch_device_generate_test

* remove @

* finish

* refactor

* add test

* fix test

* Attempt at simplification.

* Small fix.

* Fixing non existing AutoModel for TF.

* Naming.

* Remove extra condition.
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
parent f8bd8c6c
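In short, this change turns each task's per-framework entry in SUPPORTED_TASKS from a single auto-model class (or None) into a tuple of candidate classes (or an empty tuple), and teaches the loader to try each candidate in order. A minimal sketch of that fallback pattern, mirroring the loop added to infer_framework_load_model below (checkpoint and exception types taken from this diff; assumes PyTorch and transformers are installed):

from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

checkpoint = "facebook/blenderbot_small-90M"  # a seq2seq conversational model
candidates = (AutoModelForSeq2SeqLM, AutoModelForCausalLM)  # tried in order

model = None
for model_class in candidates:
    try:
        model = model_class.from_pretrained(checkpoint)
        break  # keep the first class that loads successfully
    except (OSError, ValueError):
        continue
if model is None:
    raise ValueError(f"Could not load {checkpoint} with any of {candidates}")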
@@ -37,7 +37,7 @@ from .base import (
     PipelineDataFormat,
     PipelineException,
     get_default_model,
-    infer_framework_from_model,
+    infer_framework_load_model,
 )
 from .conversational import Conversation, ConversationalPipeline
 from .feature_extraction import FeatureExtractionPipeline
@@ -110,14 +110,14 @@ TASK_ALIASES = {
 SUPPORTED_TASKS = {
     "feature-extraction": {
         "impl": FeatureExtractionPipeline,
-        "tf": TFAutoModel if is_tf_available() else None,
-        "pt": AutoModel if is_torch_available() else None,
+        "tf": (TFAutoModel,) if is_tf_available() else (),
+        "pt": (AutoModel,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
     },
     "text-classification": {
         "impl": TextClassificationPipeline,
-        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
-        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
+        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
+        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
         "default": {
             "model": {
                 "pt": "distilbert-base-uncased-finetuned-sst-2-english",
@@ -127,8 +127,8 @@ SUPPORTED_TASKS = {
     },
     "token-classification": {
         "impl": TokenClassificationPipeline,
-        "tf": TFAutoModelForTokenClassification if is_tf_available() else None,
-        "pt": AutoModelForTokenClassification if is_torch_available() else None,
+        "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
+        "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
         "default": {
             "model": {
                 "pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
@@ -138,16 +138,16 @@ SUPPORTED_TASKS = {
     },
     "question-answering": {
         "impl": QuestionAnsweringPipeline,
-        "tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
-        "pt": AutoModelForQuestionAnswering if is_torch_available() else None,
+        "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
+        "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
         "default": {
             "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
         },
     },
     "table-question-answering": {
         "impl": TableQuestionAnsweringPipeline,
-        "pt": AutoModelForTableQuestionAnswering if is_torch_available() else None,
-        "tf": None,
+        "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
+        "tf": (),
         "default": {
             "model": {
                 "pt": "google/tapas-base-finetuned-wtq",
@@ -158,21 +158,21 @@ SUPPORTED_TASKS = {
     },
     "fill-mask": {
         "impl": FillMaskPipeline,
-        "tf": TFAutoModelForMaskedLM if is_tf_available() else None,
-        "pt": AutoModelForMaskedLM if is_torch_available() else None,
+        "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
+        "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
     },
     "summarization": {
         "impl": SummarizationPipeline,
-        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
-        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
+        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
+        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
     },
     # This task is a special case as it's parametrized by SRC, TGT languages.
     "translation": {
         "impl": TranslationPipeline,
-        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
-        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
+        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
+        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {
             ("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
             ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
@@ -181,20 +181,20 @@ SUPPORTED_TASKS = {
     },
     "text2text-generation": {
         "impl": Text2TextGenerationPipeline,
-        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
-        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
+        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
+        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "text-generation": {
         "impl": TextGenerationPipeline,
-        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
-        "pt": AutoModelForCausalLM if is_torch_available() else None,
+        "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
+        "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
     },
     "zero-shot-classification": {
         "impl": ZeroShotClassificationPipeline,
-        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
-        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
+        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
+        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
         "default": {
             "model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
             "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
@@ -203,14 +203,14 @@ SUPPORTED_TASKS = {
     },
     "conversational": {
         "impl": ConversationalPipeline,
-        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
-        "pt": AutoModelForCausalLM if is_torch_available() else None,
+        "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
+        "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
         "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
     },
     "image-classification": {
         "impl": ImageClassificationPipeline,
-        "tf": None,
-        "pt": AutoModelForImageClassification if is_torch_available() else None,
+        "tf": (),
+        "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "google/vit-base-patch16-224"}},
     },
 }
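The conversational entry above is the motivating case: it now lists a seq2seq class ahead of the causal-LM class, so a checkpoint such as facebook/blenderbot_small-90M resolves from a plain model id. A hedged usage example mirroring the slow test added at the bottom of this diff (requires PyTorch and downloads weights):

from transformers import Conversation, pipeline

# Before this change, "conversational" only mapped to causal-LM auto classes,
# so a seq2seq checkpoint could not be loaded from its model id alone.
agent = pipeline("conversational", model="facebook/blenderbot_small-90M")
print(agent(Conversation("My name is Sarah and I live in London")))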
@@ -379,53 +379,35 @@ def pipeline(
     >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
     >>> pipeline('ner', model=model, tokenizer=tokenizer)
     """
     # Retrieve the task
     targeted_task, task_options = check_task(task)
+    task_class = targeted_task["impl"]
 
     # Use default model/config/tokenizer for the task if no model is provided
     if model is None:
         # At that point framework might still be undetermined
         model = get_default_model(targeted_task, framework, task_options)
 
-    model_name = model if isinstance(model, str) else None
-
-    # Infer the framework form the model
-    if framework is None:
-        framework, model = infer_framework_from_model(model, targeted_task, revision=revision, task=task)
-
-    task_class, model_class = targeted_task["impl"], targeted_task[framework]
-
+    # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
+    model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
 
     # Config is the primordial information item.
     # Instantiate config if needed
     if isinstance(config, str):
         config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
     elif config is None and isinstance(model, str):
         config = AutoConfig.from_pretrained(model, revision=revision, _from_pipeline=task, **model_kwargs)
 
-    # Instantiate model if needed
-    if isinstance(model, str):
-        # Handle transparent TF/PT model conversion
-        if framework == "pt" and model.endswith(".h5"):
-            model_kwargs["from_tf"] = True
-            logger.warning(
-                "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
-                "Trying to load the model with PyTorch."
-            )
-        elif framework == "tf" and model.endswith(".bin"):
-            model_kwargs["from_pt"] = True
-            logger.warning(
-                "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
-                "Trying to load the model with Tensorflow."
-            )
+    model_name = model if isinstance(model, str) else None
 
-        if model_class is None:
-            raise ValueError(
-                f"Pipeline using {framework} framework, but this framework is not supported by this pipeline."
-            )
-
-        # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
-        model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
-
-        model = model_class.from_pretrained(
-            model, config=config, revision=revision, _from_pipeline=task, **model_kwargs
-        )
+    # Infer the framework from the model
+    # Forced if framework already defined, inferred if it's None
+    # Will load the correct model if possible
+    model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
+    framework, model = infer_framework_load_model(
+        model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task
+    )
 
     model_config = model.config
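Note that the new call threads framework through to the loader, so an explicit choice restricts the candidate classes rather than being re-inferred (see look_pt/look_tf in the base.py hunk below). A hedged example, assuming PyTorch is installed:

from transformers import pipeline

# With framework="pt", infer_framework_load_model only considers the "pt"
# tuple plus the PyTorch classes named in config.architectures.
classifier = pipeline("text-classification", framework="pt")
print(classifier("Model classes are resolved from a tuple now."))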
......
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import csv
+import importlib
 import json
 import os
 import pickle
@@ -21,11 +22,12 @@ import warnings
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from os.path import abspath, exists
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
 from ..modelcard import ModelCard
+from ..models.auto.configuration_auto import AutoConfig
 from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
 from ..utils import logging
@@ -48,8 +50,13 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def infer_framework_from_model(
-    model, model_classes: Optional[Dict[str, type]] = None, task: Optional[str] = None, **model_kwargs
+def infer_framework_load_model(
+    model,
+    config: AutoConfig,
+    model_classes: Optional[Dict[str, Tuple[type]]] = None,
+    task: Optional[str] = None,
+    framework: Optional[str] = None,
+    **model_kwargs
 ):
     """
     Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
@@ -64,6 +71,8 @@ def infer_framework_from_model(
         model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
             The model to infer the framework from. If :obj:`str`, a checkpoint name.
+        config (:class:`~transformers.AutoConfig`):
+            The config associated with the model, used to pick the correct model class.
         model_classes (dictionary :obj:`str` to :obj:`tuple` of :obj:`type`, `optional`):
             A mapping from framework name to a tuple of candidate model classes.
         task (:obj:`str`):
@@ -83,24 +92,100 @@ def infer_framework_from_model(
         )
 
     if isinstance(model, str):
         model_kwargs["_from_pipeline"] = task
-        if is_torch_available() and not is_tf_available():
-            model_class = model_classes.get("pt", AutoModel)
-            model = model_class.from_pretrained(model, **model_kwargs)
-        elif is_tf_available() and not is_torch_available():
-            model_class = model_classes.get("tf", TFAutoModel)
-            model = model_class.from_pretrained(model, **model_kwargs)
-        else:
+        class_tuple = ()
+        look_pt = is_torch_available() and framework in {"pt", None}
+        look_tf = is_tf_available() and framework in {"tf", None}
+        if model_classes:
+            if look_pt:
+                class_tuple = class_tuple + model_classes.get("pt", (AutoModel,))
+            if look_tf:
+                class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,))
+        if config.architectures:
+            classes = []
+            for architecture in config.architectures:
+                transformers_module = importlib.import_module("transformers")
+                if look_pt:
+                    _class = getattr(transformers_module, architecture, None)
+                    if _class is not None:
+                        classes.append(_class)
+                if look_tf:
+                    _class = getattr(transformers_module, f"TF{architecture}", None)
+                    if _class is not None:
+                        classes.append(_class)
+            class_tuple = class_tuple + tuple(classes)
+
+        if len(class_tuple) == 0:
+            raise ValueError(f"Pipeline cannot infer suitable model classes from {model}")
+
+        for model_class in class_tuple:
+            kwargs = model_kwargs.copy()
+            if framework == "pt" and model.endswith(".h5"):
+                kwargs["from_tf"] = True
+                logger.warning(
+                    "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
+                    "Trying to load the model with PyTorch."
+                )
+            elif framework == "tf" and model.endswith(".bin"):
+                kwargs["from_pt"] = True
+                logger.warning(
+                    "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
+                    "Trying to load the model with Tensorflow."
+                )
+
             try:
-                model_class = model_classes.get("pt", AutoModel)
-                model = model_class.from_pretrained(model, **model_kwargs)
-            except OSError:
-                model_class = model_classes.get("tf", TFAutoModel)
-                model = model_class.from_pretrained(model, **model_kwargs)
+                model = model_class.from_pretrained(model, **kwargs)
+                # Stop loading on the first successful load.
+                break
+            except (OSError, ValueError):
+                continue
+
+    if isinstance(model, str):
+        raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
+
     framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
     return framework, model
+
+
+def infer_framework_from_model(
+    model,
+    model_classes: Optional[Dict[str, Tuple[type]]] = None,
+    task: Optional[str] = None,
+    framework: Optional[str] = None,
+    **model_kwargs
+):
+    """
+    Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
+
+    If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
+    :obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
+    Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.
+
+    If both frameworks are installed and available for :obj:`model`, PyTorch is selected.
+
+    Args:
+        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
+            The model to infer the framework from. If :obj:`str`, a checkpoint name.
+        model_classes (dictionary :obj:`str` to :obj:`tuple` of :obj:`type`, `optional`):
+            A mapping from framework name to a tuple of candidate model classes.
+        task (:obj:`str`):
+            The task defining which pipeline will be returned.
+        model_kwargs:
+            Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
+            **model_kwargs)` function.
+
+    Returns:
+        :obj:`Tuple`: A tuple (framework, model).
+    """
+    if isinstance(model, str):
+        config = AutoConfig.from_pretrained(model, _from_pipeline=task, **model_kwargs)
+    else:
+        config = model.config
+    return infer_framework_load_model(
+        model, config, model_classes=model_classes, _from_pipeline=task, task=task, framework=framework, **model_kwargs
+    )
 def get_framework(model, revision: Optional[str] = None):
     """
     Select framework (TensorFlow or PyTorch) to use.
@@ -534,7 +619,7 @@ class Pipeline(_ScikitCompat):
     ):
         if framework is None:
-            framework, model = infer_framework_from_model(model)
+            framework, model = infer_framework_load_model(model, config=model.config)
 
         self.task = task
         self.model = model
......
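A hedged usage sketch of the new loader (imported via the pipelines package, which re-exports it per the first hunk of this diff; assumes PyTorch is installed and downloads weights):

from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers.pipelines import infer_framework_load_model

checkpoint = "facebook/blenderbot_small-90M"
config = AutoConfig.from_pretrained(checkpoint)
model_classes = {"pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM), "tf": ()}

# Tries AutoModelForSeq2SeqLM, then AutoModelForCausalLM, then any PyTorch
# class named in config.architectures; returns on the first successful load.
framework, model = infer_framework_load_model(checkpoint, config, model_classes=model_classes)
assert framework == "pt"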
@@ -18,6 +18,8 @@ from transformers import (
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
+    BlenderbotSmallForConditionalGeneration,
+    BlenderbotSmallTokenizer,
     Conversation,
     ConversationalPipeline,
     is_torch_available,
@@ -389,3 +391,32 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
         self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
         self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
         self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")
+
+    @require_torch
+    @slow
+    def test_from_pipeline_conversation(self):
+        model_id = "facebook/blenderbot_small-90M"
+
+        # from model id
+        conversation_agent_from_model_id = pipeline("conversational", model=model_id, tokenizer=model_id)
+
+        # from model object
+        model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_id)
+        tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_id)
+        conversation_agent_from_model = pipeline("conversational", model=model, tokenizer=tokenizer)
+
+        conversation = Conversation("My name is Sarah and I live in London")
+        conversation_copy = Conversation("My name is Sarah and I live in London")
+
+        result_model_id = conversation_agent_from_model_id([conversation])
+        result_model = conversation_agent_from_model([conversation_copy])
+
+        # check for equality
+        self.assertEqual(
+            result_model_id.generated_responses[0],
+            "hi sarah, i live in london as well. do you have any plans for the weekend?",
+        )
+        self.assertEqual(
+            result_model_id.generated_responses[0],
+            result_model.generated_responses[0],
+        )
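The test exercises both construction paths because loading from a bare model id is exactly what used to fail for seq2seq conversational models. A quick hedged check that the id path now picks the same seq2seq head as the explicit construction (slow, downloads weights):

from transformers import BlenderbotSmallForConditionalGeneration, pipeline

agent = pipeline("conversational", model="facebook/blenderbot_small-90M")
# The (AutoModelForSeq2SeqLM, AutoModelForCausalLM) tuple resolves the id to
# the conditional-generation head rather than failing on the causal-LM class.
assert isinstance(agent.model, BlenderbotSmallForConditionalGeneration)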