"docs/vscode:/vscode.git/clone" did not exist on "e52f1cb669b77909a319bf084d45d2b2eb92c372"
Unverified commit 06a6fea7, authored by Sylvain Gugger, committed by GitHub
Browse files

Instantiate model only once in pipeline (#10888)



* Instantiate model only once in pipeline

* Remove documentation of deprecated method

* Add FutureWarning

* Update src/transformers/pipelines/base.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent cc2366bb
......@@ -47,6 +47,4 @@ Data format
Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: transformers.pipelines.get_framework
.. autoclass:: transformers.pipelines.PipelineException
......@@ -34,7 +34,7 @@ from .base import (
PipelineDataFormat,
PipelineException,
get_default_model,
get_framework,
infer_framework_from_model,
)
from .conversational import Conversation, ConversationalPipeline
from .feature_extraction import FeatureExtractionPipeline
......@@ -341,10 +341,6 @@ def pipeline(
# At that point framework might still be undetermined
model = get_default_model(targeted_task, framework, task_options)
framework = framework or get_framework(model)
task_class, model_class = targeted_task["impl"], targeted_task[framework]
# Try to infer tokenizer from model or config name (if provided as str)
if tokenizer is None:
if isinstance(model, str):
......@@ -365,6 +361,12 @@ def pipeline(
elif isinstance(config, str):
modelcard = config
# Infer the framework from the model
if framework is None:
framework, model = infer_framework_from_model(model, targeted_task, revision=revision)
task_class, model_class = targeted_task["impl"], targeted_task[framework]
# Instantiate tokenizer if needed
if isinstance(tokenizer, (str, tuple)):
if isinstance(tokenizer, tuple):
......@@ -406,16 +408,15 @@ def pipeline(
)
model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs)
if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params:
if key.startswith("translation"):
task = key
warnings.warn(
'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{}"'.format(
task
),
UserWarning,
)
break
if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params:
if key.startswith("translation"):
task = key
warnings.warn(
f'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{task}"',
UserWarning,
)
break
return task_class(model=model, tokenizer=tokenizer, modelcard=modelcard, framework=framework, task=task, **kwargs)
......@@ -17,6 +17,7 @@ import json
import os
import pickle
import sys
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from os.path import abspath, exists
......@@ -46,6 +47,55 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def infer_framework_from_model(model, model_classes: Optional[Dict[str, type]] = None, revision: Optional[str] = None):
    """
    Select framework (TensorFlow or PyTorch) to use from the :obj:`model` passed. Returns a tuple (framework, model).

    If :obj:`model` is instantiated, this function will just infer the framework from the model class. Otherwise
    :obj:`model` is actually a checkpoint name and this method will try to instantiate it using :obj:`model_classes`.
    Since we don't want to instantiate the model twice, this model is returned for use by the pipeline.

    If both frameworks are installed and available for :obj:`model`, PyTorch is selected.

    Args:
        model (:obj:`str`, :class:`~transformers.PreTrainedModel` or :class:`~transformers.TFPreTrainedModel`):
            The model to infer the framework from. If :obj:`str`, a checkpoint name.
        model_classes (dictionary :obj:`str` to :obj:`type`, `optional`):
            A mapping framework to class. Missing entries (or an omitted mapping) fall back to :obj:`AutoModel` for
            ``"pt"`` and :obj:`TFAutoModel` for ``"tf"``.
        revision (:obj:`str`, `optional`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
            identifier allowed by git.

    Returns:
        :obj:`Tuple`: A tuple framework, model.

    Raises:
        RuntimeError: If neither TensorFlow nor PyTorch is available.
    """
    if not is_tf_available() and not is_torch_available():
        raise RuntimeError(
            "At least one of TensorFlow 2.0 or PyTorch should be installed. "
            "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
            "To install PyTorch, read the instructions at https://pytorch.org/."
        )

    if isinstance(model, str):
        # `model_classes` is optional; without this guard the `.get` calls below would raise an AttributeError
        # on `None` instead of falling back to the generic auto classes.
        if model_classes is None:
            model_classes = {}

        if is_torch_available() and not is_tf_available():
            model_class = model_classes.get("pt", AutoModel)
            model = model_class.from_pretrained(model, revision=revision)
        elif is_tf_available() and not is_torch_available():
            model_class = model_classes.get("tf", TFAutoModel)
            model = model_class.from_pretrained(model, revision=revision)
        else:
            # Both frameworks are installed: try PyTorch first and fall back to TensorFlow when loading the
            # checkpoint with the PyTorch class raises an OSError.
            try:
                model_class = model_classes.get("pt", AutoModel)
                model = model_class.from_pretrained(model, revision=revision)
            except OSError:
                model_class = model_classes.get("tf", TFAutoModel)
                model = model_class.from_pretrained(model, revision=revision)

    # The framework is inferred from the class name: a "TF" prefix means TensorFlow, anything else PyTorch.
    framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
    return framework, model
def get_framework(model, revision: Optional[str] = None):
"""
Select framework (TensorFlow or PyTorch) to use.
......@@ -55,6 +105,10 @@ def get_framework(model, revision: Optional[str] = None):
If both frameworks are installed, picks the one corresponding to the model passed (either a model class or
the model name). If no specific model is provided, defaults to using PyTorch.
"""
warnings.warn(
"`get_framework` is deprecated and will be removed in v5, use `infer_framework_from_model` instead.",
FutureWarning,
)
if not is_tf_available() and not is_torch_available():
raise RuntimeError(
"At least one of TensorFlow 2.0 or PyTorch should be installed. "
......@@ -474,7 +528,7 @@ class Pipeline(_ScikitCompat):
):
if framework is None:
framework = get_framework(model)
framework = infer_framework_from_model(model)
self.task = task
self.model = model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment