"tests/ctrl/test_modeling_ctrl.py" did not exist on "fbd02d46939f8c1779d06b0881c41a4355d725ff"
Unverified Commit 49cd736a authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

feat: add pipeline registry abstraction (#17905)



* feat: add pipeline registry abstraction

- added `PipelineRegistry` abstraction
- updates `add_new_pipeline.mdx` (english docs) to reflect the api addition
- migrate `check_task` and `get_supported_tasks` from
  transformers/pipelines/__init__.py to
  transformers/pipelines/base.py#PipelineRegistry.{check_task,get_supported_tasks}
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: update with upstream/main

chore: Apply suggestions from sgugger's code review
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* chore: PR updates

- revert src/transformers/dependency_versions_table.py from upstream/main
- updates pipeline registry to use global variables
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* tests: add tests for pipeline registry
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* tests: add test for output warning.
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: fmt and cleanup unused imports
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: change imports to top of the file and address comments
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 9cb7cef2
...@@ -111,8 +111,35 @@ of arguments for ease of use (audio files, can be filenames, URLs or pure bytes) ...@@ -111,8 +111,35 @@ of arguments for ease of use (audio files, can be filenames, URLs or pure bytes)
## Adding it to the list of supported tasks
To register your `new-task` to the list of supported tasks, provide the following task template:
```python
my_new_task = {
"impl": MyPipeline,
"tf": (),
"pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
"default": {"model": {"pt": "user/awesome_model"}},
"type": "audio", # current support type: text, audio, image, multimodal
}
```
<Tip>
Take a look at the `src/transformers/pipelines/__init__.py` and the dictionary `SUPPORTED_TASKS` to see how a task is defined.
If possible your custom task should provide a default model.
</Tip>
Then add your custom task to the list of supported tasks via
`PIPELINE_REGISTRY.register_pipeline()`:
```python
from transformers.pipelines import PIPELINE_REGISTRY
PIPELINE_REGISTRY.register_pipeline("new-task", my_new_task)
```
## Adding tests
......
...@@ -41,6 +41,7 @@ from .base import ( ...@@ -41,6 +41,7 @@ from .base import (
Pipeline, Pipeline,
PipelineDataFormat, PipelineDataFormat,
PipelineException, PipelineException,
PipelineRegistry,
get_default_model_and_revision, get_default_model_and_revision,
infer_framework_load_model, infer_framework_load_model,
) )
...@@ -309,14 +310,14 @@ for task, values in SUPPORTED_TASKS.items(): ...@@ -309,14 +310,14 @@ for task, values in SUPPORTED_TASKS.items():
elif values["type"] != "multimodal": elif values["type"] != "multimodal":
raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}") raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
PIPELINE_REGISTRY = PipelineRegistry(supported_tasks=SUPPORTED_TASKS, task_aliases=TASK_ALIASES)
def get_supported_tasks() -> List[str]: def get_supported_tasks() -> List[str]:
""" """
Returns a list of supported task strings. Returns a list of supported task strings.
""" """
supported_tasks = list(SUPPORTED_TASKS.keys()) + list(TASK_ALIASES.keys()) return PIPELINE_REGISTRY.get_supported_tasks()
supported_tasks.sort()
return supported_tasks
def get_task(model: str, use_auth_token: Optional[str] = None) -> str: def get_task(model: str, use_auth_token: Optional[str] = None) -> str:
...@@ -375,20 +376,7 @@ def check_task(task: str) -> Tuple[Dict, Any]: ...@@ -375,20 +376,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
""" """
if task in TASK_ALIASES: return PIPELINE_REGISTRY.check_task(task)
task = TASK_ALIASES[task]
if task in SUPPORTED_TASKS:
targeted_task = SUPPORTED_TASKS[task]
return targeted_task, None
if task.startswith("translation"):
tokens = task.split("_")
if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
targeted_task = SUPPORTED_TASKS["translation"]
return targeted_task, (tokens[1], tokens[3])
raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
raise KeyError(f"Unknown task {task}, available tasks are {get_supported_tasks() + ['translation_XX_to_YY']}")
def pipeline( def pipeline(
......
...@@ -1087,3 +1087,41 @@ class ChunkPipeline(Pipeline): ...@@ -1087,3 +1087,41 @@ class ChunkPipeline(Pipeline):
model_iterator = PipelinePackIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size) model_iterator = PipelinePackIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params) final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
return final_iterator return final_iterator
class PipelineRegistry:
    """Registry mapping pipeline task names to their implementation metadata.

    Holds the task table (``supported_tasks``) and an alias table
    (``task_aliases``) mapping alternate names onto canonical task names.
    """

    def __init__(self, supported_tasks: Dict[str, Any], task_aliases: Dict[str, str]) -> None:
        self.supported_tasks = supported_tasks
        self.task_aliases = task_aliases

    def get_supported_tasks(self) -> List[str]:
        """Return every registered task name and alias, alphabetically sorted."""
        return sorted(list(self.supported_tasks) + list(self.task_aliases))

    def check_task(self, task: str) -> Tuple[Dict, Any]:
        """Resolve *task* to its definition dict.

        Returns a ``(task_definition, options)`` pair where ``options`` is
        ``None`` for regular tasks and a ``(src_lang, tgt_lang)`` tuple for
        ``translation_XX_to_YY`` style tasks.

        Raises:
            KeyError: if the task (after alias resolution) is unknown, or a
                translation task does not follow the ``translation_XX_to_YY``
                format.
        """
        # Map an alias onto its canonical task name first.
        resolved = self.task_aliases.get(task, task)
        if resolved in self.supported_tasks:
            return self.supported_tasks[resolved], None

        if resolved.startswith("translation"):
            parts = resolved.split("_")
            # Only "translation_XX_to_YY" is accepted; the language codes
            # become the extra options returned to the caller.
            if len(parts) == 4 and parts[0] == "translation" and parts[2] == "to":
                return self.supported_tasks["translation"], (parts[1], parts[3])
            raise KeyError(f"Invalid translation task {resolved}, use 'translation_XX_to_YY' format")

        raise KeyError(
            f"Unknown task {resolved}, available tasks are {self.get_supported_tasks() + ['translation_XX_to_YY']}"
        )

    def register_pipeline(self, task: str, task_impl: Dict[str, Any]) -> None:
        """Register (or overwrite, with a warning) the definition for *task*."""
        if task in self.supported_tasks:
            logger.warning(f"{task} is already registered. Overwriting pipeline for task {task}...")
        self.supported_tasks[task] = task_impl

    def to_dict(self):
        """Return the live task table (not a copy)."""
        return self.supported_tasks
...@@ -28,6 +28,7 @@ from transformers import ( ...@@ -28,6 +28,7 @@ from transformers import (
FEATURE_EXTRACTOR_MAPPING, FEATURE_EXTRACTOR_MAPPING,
TOKENIZER_MAPPING, TOKENIZER_MAPPING,
AutoFeatureExtractor, AutoFeatureExtractor,
AutoModelForSequenceClassification,
AutoTokenizer, AutoTokenizer,
DistilBertForSequenceClassification, DistilBertForSequenceClassification,
IBertConfig, IBertConfig,
...@@ -35,9 +36,10 @@ from transformers import ( ...@@ -35,9 +36,10 @@ from transformers import (
TextClassificationPipeline, TextClassificationPipeline,
pipeline, pipeline,
) )
from transformers.pipelines import get_task from transformers.pipelines import PIPELINE_REGISTRY, get_task
from transformers.pipelines.base import _pad from transformers.pipelines.base import Pipeline, _pad
from transformers.testing_utils import ( from transformers.testing_utils import (
CaptureLogger,
is_pipeline_test, is_pipeline_test,
nested_simplify, nested_simplify,
require_scatter, require_scatter,
...@@ -46,6 +48,7 @@ from transformers.testing_utils import ( ...@@ -46,6 +48,7 @@ from transformers.testing_utils import (
require_torch, require_torch,
slow, slow,
) )
from transformers.utils import logging as transformers_logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -746,3 +749,51 @@ class PipelineUtilsTest(unittest.TestCase): ...@@ -746,3 +749,51 @@ class PipelineUtilsTest(unittest.TestCase):
models_are_equal = False models_are_equal = False
return models_are_equal return models_are_equal
class CustomPipeline(Pipeline):
    """Minimal `Pipeline` subclass used to exercise custom-task registration."""

    def _sanitize_parameters(self, **kwargs):
        # Forward only the recognized "maybe_arg" keyword to preprocess();
        # forward/postprocess take no extra parameters.
        preprocess_kwargs = {"maybe_arg": kwargs["maybe_arg"]} if "maybe_arg" in kwargs else {}
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, maybe_arg=2):
        # Tokenize to PyTorch tensors; "maybe_arg" is accepted but unused here.
        return self.tokenizer(text, return_tensors="pt")

    def _forward(self, model_inputs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        # Softmax over the last axis turns logits into a NumPy probability array.
        return model_outputs["logits"].softmax(-1).numpy()
@is_pipeline_test
class PipelineRegistryTest(unittest.TestCase):
    def test_warning_logs(self):
        """Registering over an existing task must emit an overwrite warning."""
        transformers_logging.set_verbosity_debug()
        logger_ = transformers_logging.get_logger("transformers.pipelines.base")
        alias = "text-classification"
        # Keep the genuine task definition: the original test clobbered the
        # global registry entry with {} and never restored it, breaking any
        # later test that uses the text-classification pipeline.
        original_task = PIPELINE_REGISTRY.supported_tasks[alias]
        try:
            with CaptureLogger(logger_) as cm:
                PIPELINE_REGISTRY.register_pipeline(alias, {})
            self.assertIn(f"{alias} is already registered", cm.out)
        finally:
            # Restore the real pipeline so other tests see an intact registry.
            PIPELINE_REGISTRY.register_pipeline(alias, original_task)

    @require_torch
    def test_register_pipeline(self):
        """A custom task template can be registered and resolved by name."""
        custom_text_classification = {
            "impl": CustomPipeline,
            "tf": (),
            "pt": (AutoModelForSequenceClassification,),
            "default": {"model": {"pt": "hf-internal-testing/tiny-random-distilbert"}},
            "type": "text",
        }
        PIPELINE_REGISTRY.register_pipeline("custom-text-classification", custom_text_classification)
        # Remove the temporary task afterwards so the global registry is not
        # polluted for other tests in the same process.
        self.addCleanup(PIPELINE_REGISTRY.supported_tasks.pop, "custom-text-classification", None)

        assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()

        task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
        self.assertEqual(task_def, custom_text_classification)
        self.assertEqual(task_def["type"], "text")
        self.assertEqual(task_def["impl"], CustomPipeline)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment