Unverified Commit 5c153079 authored by Nicolas Patry, committed by GitHub

Adding some quality of life for `pipeline` function. (#14322)



* Adding some quality of life for `pipeline` function.

* Update docs/source/main_classes/pipelines.rst
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/__init__.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Improve the tests.
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 321eb562
......@@ -45,7 +45,7 @@ The pipeline abstraction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The `pipeline` abstraction is a wrapper around all the other available pipelines. It is instantiated as any other
pipeline but requires an additional argument which is the `task`.
pipeline but can provide additional quality-of-life features.
Simple call on one item:
......@@ -55,6 +55,15 @@ Simple call on one item:
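.. code-block::

    >>> pipe = pipeline("text-classification")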
>>> pipe("This restaurant is awesome")
[{'label': 'POSITIVE', 'score': 0.9998743534088135}]
If you want to use a specific model from the `hub <https://huggingface.co>`__, you can omit the task if the model on
the hub already defines it:
.. code-block::

    >>> pipe = pipeline(model="roberta-large-mnli")
    >>> pipe("This restaurant is awesome")
    [{'label': 'POSITIVE', 'score': 0.9998743534088135}]
To call a pipeline on many items, you can call it with a `list`:
.. code-block::
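
    >>> pipe = pipeline("text-classification")
    >>> pipe(["This restaurant is awesome", "This restaurant is awful"])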
......@@ -226,6 +235,32 @@ For users, a rule of thumb is:
- The larger the GPU, the more likely batching is to be beneficial.
- As soon as you enable batching, make sure you can handle OOMs gracefully (a defensive sketch follows).
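For example, a retry loop that halves the batch size on an out-of-memory error could look like this (a minimal
sketch, not library code; it assumes a CUDA device and that an OOM surfaces as a ``RuntimeError``):

.. code-block::

    from transformers import pipeline

    pipe = pipeline("text-classification", device=0)
    sentences = ["This restaurant is awesome"] * 1000  # hypothetical workload

    batch_size = 64
    results = None
    while batch_size > 0:
        try:
            results = pipe(sentences, batch_size=batch_size)
            break
        except RuntimeError:
            # CUDA out-of-memory errors are raised as RuntimeError; retry with a smaller batch
            batch_size //= 2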
Pipeline custom code
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If you want to override a specific pipeline, don't hesitate to create an issue for your task at hand: the goal of the
pipelines is to be easy to use and to support most cases, so :obj:`transformers` might be able to support your use
case directly.

If you simply want to try, you can:
- Subclass your pipeline of choice
.. code-block::

    class MyPipeline(TextClassificationPipeline):
        def postprocess(...):
            ...
            scores = scores * 100
            ...

    my_pipeline = MyPipeline(model=model, tokenizer=tokenizer, ...)
    # or if you use `pipeline` function, then:
    my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
That should enable you to write all the custom code you want.
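As a concrete illustration, a fleshed-out version of the sketch above might look like the following (an illustrative
example, not library code: the ``PercentagePipeline`` name is a placeholder, the model id is just an example of a hub
model with a ``pipeline_tag``, and it assumes the parent's ``postprocess`` returns ``{'label', 'score'}`` dicts):

.. code-block::

    from transformers import TextClassificationPipeline, pipeline

    class PercentagePipeline(TextClassificationPipeline):
        # Hypothetical subclass: report scores as percentages instead of probabilities.
        def postprocess(self, model_outputs, **kwargs):
            results = super().postprocess(model_outputs, **kwargs)
            single = isinstance(results, dict)
            if single:
                results = [results]
            for result in results:
                result["score"] = result["score"] * 100
            return results[0] if single else results

    pct_pipe = pipeline(
        model="distilbert-base-uncased-finetuned-sst-2-english",
        pipeline_class=PercentagePipeline,
    )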
Implementing a pipeline
......
......@@ -2,6 +2,9 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
import io
import json
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
......@@ -21,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from ..configuration_utils import PretrainedConfig
from ..feature_extraction_utils import PreTrainedFeatureExtractor
from ..file_utils import is_tf_available, is_torch_available
from ..file_utils import http_get, is_tf_available, is_torch_available
from ..models.auto.configuration_auto import AutoConfig
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
......@@ -248,6 +251,29 @@ SUPPORTED_TASKS = {
}
def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
    """Infer the task for `model` from the `pipeline_tag` stored in the model's hub metadata."""
    tmp = io.BytesIO()
    headers = {}
    if use_auth_token:
        headers["Authorization"] = f"Bearer {use_auth_token}"

    try:
        # Fetch the model's metadata from the hub API
        http_get(f"https://huggingface.co/api/models/{model}", tmp, headers=headers)
        tmp.seek(0)
        body = tmp.read()
        data = json.loads(body)
    except Exception as e:
        raise RuntimeError(f"Instantiating a pipeline without a task set raised an error: {e}")
    if "pipeline_tag" not in data:
        raise RuntimeError(
            f"The model {model} does not seem to have a correct `pipeline_tag` set to infer the task automatically"
        )
    if data.get("library_name", "transformers") != "transformers":
        raise RuntimeError(f"This model is meant to be used with {data['library_name']}, not with transformers")
    task = data["pipeline_tag"]
    return task
def check_task(task: str) -> Tuple[Dict, Any]:
"""
Checks an incoming task string, to validate it's correct and return the default Pipeline and Model classes, and
......@@ -299,7 +325,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
def pipeline(
    task: str,
    task: str = None,
    model: Optional = None,
    config: Optional[Union[str, PretrainedConfig]] = None,
    tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
......@@ -309,6 +335,7 @@ def pipeline(
    use_fast: bool = True,
    use_auth_token: Optional[Union[str, bool]] = None,
    model_kwargs: Dict[str, Any] = None,
    pipeline_class: Optional[Any] = None,
    **kwargs
) -> Pipeline:
    """
......@@ -422,6 +449,14 @@ def pipeline(
"""
if model_kwargs is None:
model_kwargs = {}
if task is None and model is None:
raise RuntimeError(
"Impossible to instantiate a pipeline without either a task or a model"
"being specified."
"Please provide a task class or a model"
)
    if model is None and tokenizer is not None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline with tokenizer specified but not the model "
......@@ -435,9 +470,18 @@ def pipeline(
"Please provide a PreTrainedModel class or a path/identifier to a pretrained model when providing feature_extractor."
)
    if task is None and model is not None:
        if not isinstance(model, str):
            raise RuntimeError(
                "Inferring the task automatically requires checking the hub with a model_id defined as a `str`. "
                f"{model} is not a valid model_id."
            )
        task = get_task(model, use_auth_token)
    # Retrieve the task
    targeted_task, task_options = check_task(task)
    task_class = targeted_task["impl"]
    if pipeline_class is None:
        pipeline_class = targeted_task["impl"]

    # Use default model/config/tokenizer for the task if no model is provided
    if model is None:
......@@ -549,4 +593,4 @@ def pipeline(
    if feature_extractor is not None:
        kwargs["feature_extractor"] = feature_extractor

    return task_class(model=model, framework=framework, task=task, **kwargs)
    return pipeline_class(model=model, framework=framework, task=task, **kwargs)
......@@ -29,8 +29,10 @@ from transformers import (
    AutoTokenizer,
    IBertConfig,
    RobertaConfig,
    TextClassificationPipeline,
    pipeline,
)
from transformers.pipelines import get_task
from transformers.pipelines.base import _pad
from transformers.testing_utils import is_pipeline_test, require_torch
......@@ -261,6 +263,29 @@ class CommonPipelineTest(unittest.TestCase):
        for output in text_classifier(dataset):
            self.assertEqual(output, {"label": ANY(str), "score": ANY(float)})
    @require_torch
    def test_check_task_auto_inference(self):
        pipe = pipeline(model="Narsil/tiny-distilbert-sequence-classification")

        self.assertIsInstance(pipe, TextClassificationPipeline)

    @require_torch
    def test_pipeline_override(self):
        class MyPipeline(TextClassificationPipeline):
            pass

        text_classifier = pipeline(model="Narsil/tiny-distilbert-sequence-classification", pipeline_class=MyPipeline)

        self.assertIsInstance(text_classifier, MyPipeline)

    def test_check_task(self):
        task = get_task("gpt2")
        self.assertEqual(task, "text-generation")

        with self.assertRaises(RuntimeError):
            # Wrong framework
            get_task("espnet/siddhana_slurp_entity_asr_train_asr_conformer_raw_en_word_valid.acc.ave_10best")
@is_pipeline_test
class PipelinePadTest(unittest.TestCase):
......