Unverified Commit 06a6fea7 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Instantiate model only once in pipeline (#10888)



* Instantiate model only once in pipeline

* Remove documentation of deprecated method

* Add FutureWarning

* Update src/transformers/pipelines/base.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
parent cc2366bb
...@@ -47,6 +47,4 @@ Data format ...@@ -47,6 +47,4 @@ Data format
Utilities Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: transformers.pipelines.get_framework
.. autoclass:: transformers.pipelines.PipelineException .. autoclass:: transformers.pipelines.PipelineException
...@@ -34,7 +34,7 @@ from .base import ( ...@@ -34,7 +34,7 @@ from .base import (
PipelineDataFormat, PipelineDataFormat,
PipelineException, PipelineException,
get_default_model, get_default_model,
get_framework, infer_framework_from_model,
) )
from .conversational import Conversation, ConversationalPipeline from .conversational import Conversation, ConversationalPipeline
from .feature_extraction import FeatureExtractionPipeline from .feature_extraction import FeatureExtractionPipeline
...@@ -341,10 +341,6 @@ def pipeline( ...@@ -341,10 +341,6 @@ def pipeline(
# At that point framework might still be undetermined # At that point framework might still be undetermined
model = get_default_model(targeted_task, framework, task_options) model = get_default_model(targeted_task, framework, task_options)
framework = framework or get_framework(model)
task_class, model_class = targeted_task["impl"], targeted_task[framework]
# Try to infer tokenizer from model or config name (if provided as str) # Try to infer tokenizer from model or config name (if provided as str)
if tokenizer is None: if tokenizer is None:
if isinstance(model, str): if isinstance(model, str):
...@@ -365,6 +361,12 @@ def pipeline( ...@@ -365,6 +361,12 @@ def pipeline(
elif isinstance(config, str): elif isinstance(config, str):
modelcard = config modelcard = config
    # Infer the framework from the model
if framework is None:
framework, model = infer_framework_from_model(model, targeted_task, revision=revision)
task_class, model_class = targeted_task["impl"], targeted_task[framework]
# Instantiate tokenizer if needed # Instantiate tokenizer if needed
if isinstance(tokenizer, (str, tuple)): if isinstance(tokenizer, (str, tuple)):
if isinstance(tokenizer, tuple): if isinstance(tokenizer, tuple):
...@@ -406,14 +408,13 @@ def pipeline( ...@@ -406,14 +408,13 @@ def pipeline(
) )
model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs) model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs)
if task == "translation" and model.config.task_specific_params: if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params: for key in model.config.task_specific_params:
if key.startswith("translation"): if key.startswith("translation"):
task = key task = key
warnings.warn( warnings.warn(
'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{}"'.format( f'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{task}"',
task
),
UserWarning, UserWarning,
) )
break break
......
...@@ -17,6 +17,7 @@ import json ...@@ -17,6 +17,7 @@ import json
import os import os
import pickle import pickle
import sys import sys
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from os.path import abspath, exists from os.path import abspath, exists
...@@ -46,6 +47,55 @@ if TYPE_CHECKING: ...@@ -46,6 +47,55 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def infer_framework_from_model(model, model_classes: Optional[Dict[str, type]] = None, revision: Optional[str] = None):
    """
    Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).

    If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
    :obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
    Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.

    If both frameworks are installed and available for :obj:`model`, PyTorch is selected.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
            The model to infer the framework from. If :obj:`str`, a checkpoint name.
        model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
            A mapping framework to class. Keys ``"pt"``/``"tf"`` override the default
            :class:`~transformers.AutoModel`/:class:`~transformers.TFAutoModel` used to load a checkpoint name.
        revision (:obj:`str`, `optional`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.

    Returns:
        :obj:`Tuple`: A tuple framework ("tf" or "pt"), model (instantiated).

    Raises:
        :obj:`RuntimeError`: If neither TensorFlow 2.0 nor PyTorch is installed.
    """
    if not is_tf_available() and not is_torch_available():
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )
    if isinstance(model, str):
        # `model_classes` is optional: coalesce None to an empty mapping so the `.get`
        # lookups below fall back to the Auto classes instead of raising AttributeError.
        model_classes = model_classes if model_classes is not None else {}
        if is_torch_available() and not is_tf_available():
            model_class = model_classes.get("pt", AutoModel)
            model = model_class.from_pretrained(model, revision=revision)
        elif is_tf_available() and not is_torch_available():
            model_class = model_classes.get("tf", TFAutoModel)
            model = model_class.from_pretrained(model, revision=revision)
        else:
            # Both frameworks installed: prefer PyTorch; fall back to TensorFlow when the
            # checkpoint has no PyTorch weights (from_pretrained raises OSError in that case).
            try:
                model_class = model_classes.get("pt", AutoModel)
                model = model_class.from_pretrained(model, revision=revision)
            except OSError:
                model_class = model_classes.get("tf", TFAutoModel)
                model = model_class.from_pretrained(model, revision=revision)

    # TF model classes are conventionally prefixed with "TF"; everything else is PyTorch.
    framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
    return framework, model
def get_framework(model, revision: Optional[str] = None): def get_framework(model, revision: Optional[str] = None):
""" """
Select framework (TensorFlow or PyTorch) to use. Select framework (TensorFlow or PyTorch) to use.
...@@ -55,6 +105,10 @@ def get_framework(model, revision: Optional[str] = None): ...@@ -55,6 +105,10 @@ def get_framework(model, revision: Optional[str] = None):
If both frameworks are installed, picks the one corresponding to the model passed (either a model class or If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
the model name). If no specific model is provided, defaults to using PyTorch. the model name). If no specific model is provided, defaults to using PyTorch.
""" """
warnings.warn(
"`get_framework` is deprecated and will be removed in v5, use `infer_framework_from_model` instead.",
FutureWarning,
)
if not is_tf_available() and not is_torch_available(): if not is_tf_available() and not is_torch_available():
raise RuntimeError( raise RuntimeError(
"At least one of TensorFlow 2.0 or PyTorch should be installed. " "At least one of TensorFlow 2.0 or PyTorch should be installed. "
...@@ -474,7 +528,7 @@ class Pipeline(_ScikitCompat): ...@@ -474,7 +528,7 @@ class Pipeline(_ScikitCompat):
): ):
if framework is None: if framework is None:
framework = get_framework(model) framework = infer_framework_from_model(model)
self.task = task self.task = task
self.model = model self.model = model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment