Unverified Commit 5c153079 authored by Nicolas Patry, committed by GitHub

Adding some quality of life for `pipeline` function. (#14322)



* Adding some quality of life for `pipeline` function.

* Update docs/source/main_classes/pipelines.rst
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/__init__.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Improve the tests.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 321eb562
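In short, the commit lets `pipeline(...)` infer the task from the model's metadata on the hub and lets callers swap in their own pipeline class. A minimal sketch of the two new call patterns (the model identifiers below are simply the ones that appear in this commit's docs and tests; any hub model with a `pipeline_tag` would work):

.. code-block::

    from transformers import TextClassificationPipeline, pipeline

    # 1. The task can be omitted when the model on the hub already defines a `pipeline_tag`.
    pipe = pipeline(model="roberta-large-mnli")

    # 2. A custom pipeline class can be injected through the new `pipeline_class` argument.
    class MyPipeline(TextClassificationPipeline):
        pass

    pipe = pipeline(model="Narsil/tiny-distilbert-sequence-classification", pipeline_class=MyPipeline)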
docs/source/main_classes/pipelines.rst

@@ -45,7 +45,7 @@ The pipeline abstraction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The `pipeline` abstraction is a wrapper around all the other available pipelines. It is instantiated as any other
-pipeline but requires an additional argument which is the `task`.
+pipeline but can provide additional quality of life.

Simple call on one item:

@@ -55,6 +55,15 @@ Simple call on one item:

    >>> pipe("This restaurant is awesome")
    [{'label': 'POSITIVE', 'score': 0.9998743534088135}]

If you want to use a specific model from the `hub <https://huggingface.co>`__ you can ignore the task if the model on
the hub already defines it:

.. code-block::

    >>> pipe = pipeline(model="roberta-large-mnli")
    >>> pipe("This restaurant is awesome")
    [{'label': 'POSITIVE', 'score': 0.9998743534088135}]

To call a pipeline on many items, you can either call with a `list`.

.. code-block::
@@ -226,6 +235,32 @@ For users, a rule of thumb is:

- The larger the GPU the more likely batching is going to be more interesting
- As soon as you enable batching, make sure you can handle OOMs nicely.
Pipeline custom code
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If you want to override a specific pipeline:

Don't hesitate to create an issue for your task at hand; the goal of the pipeline is to be easy to use and support most
cases, so :obj:`transformers` might be able to support your use case.

If you simply want to try it out, you can:
- Subclass your pipeline of choice
.. code-block::
    class MyPipeline(TextClassificationPipeline):
        def postprocess(...):
            ...
            scores = scores * 100
            ...

    my_pipeline = MyPipeline(model=model, tokenizer=tokenizer, ...)
    # or if you use `pipeline` function, then:
    my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
That should enable you to write all the custom code you want.
Implementing a pipeline
......
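As a complement to the elided snippet in the "Pipeline custom code" section above, here is a hedged, self-contained sketch of such an override. The `postprocess` body is only an assumption about how one might rescale scores, and the model identifier is the tiny test model used in this commit's tests, not a recommendation:

.. code-block::

    from transformers import TextClassificationPipeline, pipeline

    class MyPipeline(TextClassificationPipeline):
        def postprocess(self, model_outputs, **kwargs):
            # Let the parent class build its usual {"label": ..., "score": ...} output,
            # then rescale the score (purely illustrative).
            result = super().postprocess(model_outputs, **kwargs)
            if isinstance(result, dict):
                result["score"] *= 100
            else:  # with `return_all_scores=True` the parent returns a list of dicts
                for item in result:
                    item["score"] *= 100
            return result

    my_pipeline = pipeline(model="Narsil/tiny-distilbert-sequence-classification", pipeline_class=MyPipeline)
    print(my_pipeline("This restaurant is awesome"))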
src/transformers/pipelines/__init__.py

@@ -2,6 +2,9 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
import io
import json
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#

@@ -21,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import PreTrainedFeatureExtractor
-from ..file_utils import is_tf_available, is_torch_available
+from ..file_utils import http_get, is_tf_available, is_torch_available
from ..models.auto.configuration_auto import AutoConfig
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer

@@ -248,6 +251,29 @@ SUPPORTED_TASKS = {
}
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
    tmp = io.BytesIO()
    headers = {}
    if use_auth_token:
        headers["Authorization"] = f"Bearer {use_auth_token}"

    try:
        http_get(f"https://huggingface.co/api/models/{model}", tmp, headers=headers)
        tmp.seek(0)
        body = tmp.read()
        data = json.loads(body)
    except Exception as e:
        raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}")
    if "pipeline_tag" not in data:
        raise RuntimeError(
            f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
        )
    if data.get("library_name", "transformers") != "transformers":
        raise RuntimeError(f"This model is meant to be used with {data['library_name']} not with transformers")
    task = data["pipeline_tag"]
    return task
def check_task(task: str) -> Tuple[Dict, Any]:
    """
    Checks an incoming task string, to validate it's correct and return the default Pipeline and Model classes, and

@@ -299,7 +325,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
def pipeline(
-    task: str,
+    task: str = None,
    model: Optional = None,
    config: Optional[Union[str, PretrainedConfig]] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,

@@ -309,6 +335,7 @@ def pipeline(
    use_fast: bool = True,
    use_auth_token: Optional[Union[str, bool]] = None,
    model_kwargs: Dict[str, Any] = None,
    pipeline_class: Optional[Any] = None,
    **kwargs
) -> Pipeline:
    """

@@ -422,6 +449,14 @@
    """
    if model_kwargs is None:
        model_kwargs = {}
    if task is None and model is None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline without either a task or a model "
            "being specified. "
            "Please provide a task class or a model"
        )
    if model is None and tokenizer is not None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline with tokenizer specified but not the model "

@@ -435,9 +470,18 @@
            "Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing feature_extractor."
        )
    if task is None and model is not None:
        if not isinstance(model, str):
            raise RuntimeError(
                "Inferring the task automatically requires to check the hub with a model_id defined as a `str`. "
                f"{model} is not a valid model_id."
            )
        task = get_task(model, use_auth_token)
    # Retrieve the task
    targeted_task, task_options = check_task(task)
-    task_class = targeted_task["impl"]
+    if pipeline_class is None:
+        pipeline_class = targeted_task["impl"]

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:

@@ -549,4 +593,4 @@
    if feature_extractor is not None:
        kwargs["feature_extractor"] = feature_extractor

-    return task_class(model=model, framework=framework, task=task, **kwargs)
+    return pipeline_class(model=model, framework=framework, task=task, **kwargs)
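For context on `get_task` above: it queries the hub's model-info endpoint and reads the `pipeline_tag` (and optional `library_name`) fields from the JSON it returns. A rough sketch of the same lookup using only the standard library, for illustration rather than as the implementation `transformers` uses:

.. code-block::

    import json
    from urllib.request import urlopen

    # Same hub endpoint that `get_task` fetches through `http_get`.
    with urlopen("https://huggingface.co/api/models/gpt2") as response:
        data = json.loads(response.read())

    print(data["pipeline_tag"])                      # "text-generation" for gpt2, as the new test asserts
    print(data.get("library_name", "transformers"))  # must be "transformers" for `pipeline` to accept it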
@@ -29,8 +29,10 @@ from transformers import (
    AutoTokenizer,
    IBertConfig,
    RobertaConfig,
    TextClassificationPipeline,
    pipeline,
)
from transformers.pipelines import get_task
from transformers.pipelines.base import _pad
from transformers.testing_utils import is_pipeline_test, require_torch

@@ -261,6 +263,29 @@ class CommonPipelineTest(unittest.TestCase):
        for output in text_classifier(dataset):
            self.assertEqual(output, {"label": ANY(str), "score": ANY(float)})
    @require_torch
    def test_check_task_auto_inference(self):
        pipe = pipeline(model="Narsil/tiny-distilbert-sequence-classification")

        self.assertIsInstance(pipe, TextClassificationPipeline)

    @require_torch
    def test_pipeline_override(self):
        class MyPipeline(TextClassificationPipeline):
            pass

        text_classifier = pipeline(model="Narsil/tiny-distilbert-sequence-classification", pipeline_class=MyPipeline)

        self.assertIsInstance(text_classifier, MyPipeline)

    def test_check_task(self):
        task = get_task("gpt2")
        self.assertEqual(task, "text-generation")

        with self.assertRaises(RuntimeError):
            # Wrong framework
            get_task("espnet/siddhana_slurp_entity_asr_train_asr_conformer_raw_en_word_valid.acc.ave_10best")
@is_pipeline_test
class PipelinePadTest(unittest.TestCase):
......