"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7bc6d76396f4a603161539aefaa6207d61260f60"
Unverified Commit 2056f26e authored by Nicolas Patry, committed by GitHub

Extend pipelines for automodel tuples (#12025)



* fix_torch_device_generate_test

* remove @

* finish

* refactor

* add test

* fix test

* Attempt at simplification.

* Small fix.

* Fixing non existing AutoModel for TF.

* Naming.

* Remove extra condition.
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
parent f8bd8c6c
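In short: each entry in SUPPORTED_TASKS now maps "pt" and "tf" to a tuple of candidate auto-model classes instead of a single class (or None), and the loader walks that tuple until one class accepts the checkpoint. A minimal sketch of that fallback idea, separate from the diff below (the class tuple and checkpoint are just the "conversational" defaults introduced by this change):

# Illustrative sketch only, not code from this commit: try each candidate class in order.
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

candidate_classes = (AutoModelForSeq2SeqLM, AutoModelForCausalLM)  # the new "conversational" tuple
checkpoint = "microsoft/DialoGPT-medium"  # the task's default model

model = None
for model_class in candidate_classes:
    try:
        model = model_class.from_pretrained(checkpoint)
        break  # stop on the first class that can load the checkpoint
    except (OSError, ValueError):
        continue  # this auto class does not match the checkpoint's config
if model is None:
    raise ValueError(f"Could not load {checkpoint} with any of {candidate_classes}.")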
src/transformers/pipelines/__init__.py
@@ -37,7 +37,7 @@ from .base import (
     PipelineDataFormat,
     PipelineException,
     get_default_model,
-    infer_framework_from_model,
+    infer_framework_load_model,
 )
 from .conversational import Conversation, ConversationalPipeline
 from .feature_extraction import FeatureExtractionPipeline
@@ -110,14 +110,14 @@ TASK_ALIASES = {
 SUPPORTED_TASKS = {
     "feature-extraction": {
         "impl": FeatureExtractionPipeline,
-        "tf": TFAutoModel if is_tf_available() else None,
-        "pt": AutoModel if is_torch_available() else None,
+        "tf": (TFAutoModel,) if is_tf_available() else (),
+        "pt": (AutoModel,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
     },
     "text-classification": {
         "impl": TextClassificationPipeline,
-        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
-        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
+        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
+        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
         "default": {
             "model": {
                 "pt": "distilbert-base-uncased-finetuned-sst-2-english",
@@ -127,8 +127,8 @@ SUPPORTED_TASKS = {
     },
     "token-classification": {
         "impl": TokenClassificationPipeline,
-        "tf": TFAutoModelForTokenClassification if is_tf_available() else None,
-        "pt": AutoModelForTokenClassification if is_torch_available() else None,
+        "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
+        "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
         "default": {
             "model": {
                 "pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
@@ -138,16 +138,16 @@ SUPPORTED_TASKS = {
     },
     "question-answering": {
         "impl": QuestionAnsweringPipeline,
-        "tf": TFAutoModelForQuestionAnswering if is_tf_available() else None,
-        "pt": AutoModelForQuestionAnswering if is_torch_available() else None,
+        "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
+        "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
         "default": {
             "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
         },
     },
     "table-question-answering": {
         "impl": TableQuestionAnsweringPipeline,
-        "pt": AutoModelForTableQuestionAnswering if is_torch_available() else None,
-        "tf": None,
+        "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
+        "tf": (),
         "default": {
             "model": {
                 "pt": "google/tapas-base-finetuned-wtq",
@@ -158,21 +158,21 @@ SUPPORTED_TASKS = {
     },
     "fill-mask": {
         "impl": FillMaskPipeline,
-        "tf": TFAutoModelForMaskedLM if is_tf_available() else None,
-        "pt": AutoModelForMaskedLM if is_torch_available() else None,
+        "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
+        "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
     },
     "summarization": {
         "impl": SummarizationPipeline,
-        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
-        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
+        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
+        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
     },
     # This task is a special case as it's parametrized by SRC, TGT languages.
     "translation": {
         "impl": TranslationPipeline,
-        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
-        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
+        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
+        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {
             ("en", "fr"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
             ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
@@ -181,20 +181,20 @@ SUPPORTED_TASKS = {
     },
     "text2text-generation": {
         "impl": Text2TextGenerationPipeline,
-        "tf": TFAutoModelForSeq2SeqLM if is_tf_available() else None,
-        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
+        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
+        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "text-generation": {
         "impl": TextGenerationPipeline,
-        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
-        "pt": AutoModelForCausalLM if is_torch_available() else None,
+        "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
+        "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
     },
     "zero-shot-classification": {
         "impl": ZeroShotClassificationPipeline,
-        "tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
-        "pt": AutoModelForSequenceClassification if is_torch_available() else None,
+        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
+        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
         "default": {
             "model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
             "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
@@ -203,14 +203,14 @@ SUPPORTED_TASKS = {
     },
     "conversational": {
         "impl": ConversationalPipeline,
-        "tf": TFAutoModelForCausalLM if is_tf_available() else None,
-        "pt": AutoModelForCausalLM if is_torch_available() else None,
+        "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
+        "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
         "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
     },
     "image-classification": {
         "impl": ImageClassificationPipeline,
-        "tf": None,
-        "pt": AutoModelForImageClassification if is_torch_available() else None,
+        "tf": (),
+        "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "google/vit-base-patch16-224"}},
     },
 }
@@ -379,53 +379,35 @@ def pipeline(
         >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
         >>> pipeline('ner', model=model, tokenizer=tokenizer)
     """
     # Retrieve the task
     targeted_task, task_options = check_task(task)
+    task_class = targeted_task["impl"]
 
     # Use default model/config/tokenizer for the task if no model is provided
     if model is None:
         # At that point framework might still be undetermined
         model = get_default_model(targeted_task, framework, task_options)
 
-    model_name = model if isinstance(model, str) else None
-
-    # Infer the framework form the model
-    if framework is None:
-        framework, model = infer_framework_from_model(model, targeted_task, revision=revision, task=task)
-
-    task_class, model_class = targeted_task["impl"], targeted_task[framework]
-
-    # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
-    model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
-
+    # Config is the primordial information item.
     # Instantiate config if needed
     if isinstance(config, str):
         config = AutoConfig.from_pretrained(config, revision=revision, _from_pipeline=task, **model_kwargs)
+    elif config is None and isinstance(model, str):
+        config = AutoConfig.from_pretrained(model, revision=revision, _from_pipeline=task, **model_kwargs)
 
-    # Instantiate model if needed
-    if isinstance(model, str):
-        # Handle transparent TF/PT model conversion
-        if framework == "pt" and model.endswith(".h5"):
-            model_kwargs["from_tf"] = True
-            logger.warning(
-                "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
-                "Trying to load the model with PyTorch."
-            )
-        elif framework == "tf" and model.endswith(".bin"):
-            model_kwargs["from_pt"] = True
-            logger.warning(
-                "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
-                "Trying to load the model with Tensorflow."
-            )
-
-        if model_class is None:
-            raise ValueError(
-                f"Pipeline using {framework} framework, but this framework is not supported by this pipeline."
-            )
-
-        model = model_class.from_pretrained(
-            model, config=config, revision=revision, _from_pipeline=task, **model_kwargs
-        )
+    model_name = model if isinstance(model, str) else None
+
+    # Retrieve use_auth_token and add it to model_kwargs to be used in .from_pretrained
+    model_kwargs["use_auth_token"] = model_kwargs.get("use_auth_token", use_auth_token)
+
+    # Infer the framework from the model
+    # Forced if framework already defined, inferred if it's None
+    # Will load the correct model if possible
+    model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]}
+    framework, model = infer_framework_load_model(
+        model, model_classes=model_classes, config=config, framework=framework, revision=revision, task=task
+    )
 
     model_config = model.config
src/transformers/pipelines/base.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import csv
+import importlib
 import json
 import os
 import pickle
@@ -21,11 +22,12 @@ import warnings
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from os.path import abspath, exists
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
 from ..modelcard import ModelCard
+from ..models.auto.configuration_auto import AutoConfig
 from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
 from ..utils import logging
@@ -48,8 +50,13 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def infer_framework_from_model(
-    model, model_classes: Optional[Dict[str, type]] = None, task: Optional[str] = None, **model_kwargs
+def infer_framework_load_model(
+    model,
+    config: AutoConfig,
+    model_classes: Optional[Dict[str, Tuple[type]]] = None,
+    task: Optional[str] = None,
+    framework: Optional[str] = None,
+    **model_kwargs
 ):
     """
     Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
@@ -64,6 +71,8 @@ def infer_framework_from_model(
         model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
             The model to infer the framework from. If :obj:`str`, a checkpoint name. The model to infer the framewrok
             from.
+        config (:class:`~transformers.AutoConfig`):
+            The config associated with the model to help using the correct class
         model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
             A mapping framework to class.
         task (:obj:`str`):
@@ -83,24 +92,100 @@ def infer_framework_from_model(
         )
     if isinstance(model, str):
         model_kwargs["_from_pipeline"] = task
-        if is_torch_available() and not is_tf_available():
-            model_class = model_classes.get("pt", AutoModel)
-            model = model_class.from_pretrained(model, **model_kwargs)
-        elif is_tf_available() and not is_torch_available():
-            model_class = model_classes.get("tf", TFAutoModel)
-            model = model_class.from_pretrained(model, **model_kwargs)
-        else:
-            try:
-                model_class = model_classes.get("pt", AutoModel)
-                model = model_class.from_pretrained(model, **model_kwargs)
-            except OSError:
-                model_class = model_classes.get("tf", TFAutoModel)
-                model = model_class.from_pretrained(model, **model_kwargs)
+        class_tuple = ()
+        look_pt = is_torch_available() and framework in {"pt", None}
+        look_tf = is_tf_available() and framework in {"tf", None}
+        if model_classes:
+            if look_pt:
+                class_tuple = class_tuple + model_classes.get("pt", (AutoModel,))
+            if look_tf:
+                class_tuple = class_tuple + model_classes.get("tf", (TFAutoModel,))
+        if config.architectures:
+            classes = []
+            for architecture in config.architectures:
+                transformers_module = importlib.import_module("transformers")
+                if look_pt:
+                    _class = getattr(transformers_module, architecture, None)
+                    if _class is not None:
+                        classes.append(_class)
+                if look_tf:
+                    _class = getattr(transformers_module, f"TF{architecture}", None)
+                    if _class is not None:
+                        classes.append(_class)
+            class_tuple = class_tuple + tuple(classes)
+
+        if len(class_tuple) == 0:
+            raise ValueError(f"Pipeline cannot infer suitable model classes from {model}")
+
+        for model_class in class_tuple:
+            kwargs = model_kwargs.copy()
+            if framework == "pt" and model.endswith(".h5"):
+                kwargs["from_tf"] = True
+                logger.warning(
+                    "Model might be a TensorFlow model (ending with `.h5`) but TensorFlow is not available. "
+                    "Trying to load the model with PyTorch."
+                )
+            elif framework == "tf" and model.endswith(".bin"):
+                kwargs["from_pt"] = True
+                logger.warning(
+                    "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
+                    "Trying to load the model with Tensorflow."
+                )
+
+            try:
+                model = model_class.from_pretrained(model, **kwargs)
+                # Stop loading on the first successful load.
+                break
+            except (OSError, ValueError):
+                continue
+
+        if isinstance(model, str):
+            raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")
 
     framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
     return framework, model
 
 
+def infer_framework_from_model(
+    model,
+    model_classes: Optional[Dict[str, Tuple[type]]] = None,
+    task: Optional[str] = None,
+    framework: Optional[str] = None,
+    **model_kwargs
+):
+    """
+    Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).
+
+    If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
+    :obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
+    Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.
+
+    If both frameworks are installed and available for :obj:`model`, PyTorch is selected.
+
+    Args:
+        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
+            The model to infer the framework from. If :obj:`str`, a checkpoint name. The model to infer the framewrok
+            from.
+        model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
+            A mapping framework to class.
+        task (:obj:`str`):
+            The task defining which pipeline will be returned.
+        model_kwargs:
+            Additional dictionary of keyword arguments passed along to the model's :obj:`from_pretrained(...,
+            **model_kwargs)` function.
+
+    Returns:
+        :obj:`Tuple`: A tuple framework, model.
+    """
+    if isinstance(model, str):
+        config = AutoConfig.from_pretrained(model, _from_pipeline=task, **model_kwargs)
+    else:
+        config = model.config
+    return infer_framework_load_model(
+        model, config, model_classes=model_classes, _from_pipeline=task, task=task, framework=framework, **model_kwargs
+    )
+
+
 def get_framework(model, revision: Optional[str] = None):
     """
     Select framework (TensorFlow or PyTorch) to use.
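As a usage sketch (not part of the diff): this is roughly how pipeline() now drives the new helper for the "conversational" task. It assumes a PyTorch install; the checkpoint is the task's default from SUPPORTED_TASKS above, and AutoModelForSeq2SeqLM is expected to be rejected before AutoModelForCausalLM succeeds.

from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers.pipelines.base import infer_framework_load_model

checkpoint = "microsoft/DialoGPT-medium"  # default "conversational" model in SUPPORTED_TASKS
config = AutoConfig.from_pretrained(checkpoint)

# Same shape as the model_classes dict built in pipeline(): one tuple of candidates per framework.
model_classes = {"pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM), "tf": ()}

framework, model = infer_framework_load_model(
    checkpoint, config=config, model_classes=model_classes, framework="pt", task="conversational"
)
print(framework, model.__class__.__name__)  # "pt" and a causal-LM architecture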
@@ -534,7 +619,7 @@ class Pipeline(_ScikitCompat):
     ):
 
         if framework is None:
-            framework, model = infer_framework_from_model(model)
+            framework, model = infer_framework_load_model(model, config=model.config)
 
         self.task = task
         self.model = model
tests/test_pipelines_conversational.py
@@ -18,6 +18,8 @@ from transformers import (
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
+    BlenderbotSmallForConditionalGeneration,
+    BlenderbotSmallTokenizer,
     Conversation,
     ConversationalPipeline,
     is_torch_available,
@@ -389,3 +391,32 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
         self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
         self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
         self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")
+
+    @require_torch
+    @slow
+    def test_from_pipeline_conversation(self):
+        model_id = "facebook/blenderbot_small-90M"
+
+        # from model id
+        conversation_agent_from_model_id = pipeline("conversational", model=model_id, tokenizer=model_id)
+
+        # from model object
+        model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_id)
+        tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_id)
+        conversation_agent_from_model = pipeline("conversational", model=model, tokenizer=tokenizer)
+
+        conversation = Conversation("My name is Sarah and I live in London")
+        conversation_copy = Conversation("My name is Sarah and I live in London")
+
+        result_model_id = conversation_agent_from_model_id([conversation])
+        result_model = conversation_agent_from_model([conversation_copy])
+        # check for equality
+        self.assertEqual(
+            result_model_id.generated_responses[0],
+            "hi sarah, i live in london as well. do you have any plans for the weekend?",
+        )
+
+        self.assertEqual(
+            result_model_id.generated_responses[0],
+            result_model.generated_responses[0],
+        )
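The user-visible effect this new test checks, as a short sketch: a Seq2Seq checkpoint such as Blenderbot small can now be handed to the conversational pipeline by model id, because AutoModelForSeq2SeqLM is tried before AutoModelForCausalLM in the task's class tuple.

from transformers import Conversation, pipeline

model_id = "facebook/blenderbot_small-90M"
chatbot = pipeline("conversational", model=model_id, tokenizer=model_id)

conversation = Conversation("My name is Sarah and I live in London")
print(chatbot(conversation).generated_responses[0])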