Unverified Commit 586dcf6b authored by Nicolas Patry, committed by GitHub

Fixing an issue where generic model types wouldn't load properly with the pipeline (#18392)

* Adding a better error message when the model is improperly configured within transformers.

* Update src/transformers/pipelines/__init__.py

* Black version.

* Overriding task aliases so that tokenizer+feature_extractor values are correct.

* Fixing task aliases by overriding their names early

* X.

* Fixing feature-extraction.

* black again.

* Normalizing `translation` too.

* Fixing the last few corner cases.

translation needs to use its non-normalized name (translation_XX_to_YY), so that the task_specific_params are correctly overloaded. This can be removed and cleaned up in a later PR.

`speech-encoder-decoder` actually REQUIRES a `tokenizer` to be passed manually, so the error needs to be discarded when the `tokenizer` is already there (a usage sketch follows this commit message).

* doc-builder fix.

* Fixing the real issue.

* Removing dead code.

* Do not import the actual config classes.
parent 14928921
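The net effect is easiest to see from the test change at the bottom of this diff: a generic SpeechEncoderDecoder checkpoint can now be handed to `pipeline()` directly, and the tokenizer and feature extractor are force-loaded from the model repo. A minimal sketch (the model id comes from the updated test; any checkpoint whose config is in MULTI_MODEL_CONFIGS behaves the same way):

```python
from transformers import pipeline

# Before this commit, SpeechEncoderDecoderConfig appears in neither
# TOKENIZER_MAPPING nor FEATURE_EXTRACTOR_MAPPING, so pipeline() skipped
# loading both components and construction failed unless they were passed in.
# After this commit, both are force-loaded from the model repo:
speech_recognizer = pipeline(
    task="automatic-speech-recognition",
    model="hf-internal-testing/tiny-random-speech-encoder-decoder",
    framework="pt",
)
```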
@@ -311,6 +311,11 @@ SUPPORTED_TASKS = {
 NO_FEATURE_EXTRACTOR_TASKS = set()
 NO_TOKENIZER_TASKS = set()
+# These model configs are special: they are generic over their task, meaning
+# any tokenizer/feature_extractor might be used for a given model, so we cannot
+# use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
+# see if the model defines such objects or not.
+MULTI_MODEL_CONFIGS = {"VisionTextDualEncoderConfig", "SpeechEncoderDecoderConfig"}
 for task, values in SUPPORTED_TASKS.items():
     if values["type"] == "text":
         NO_FEATURE_EXTRACTOR_TASKS.add(task)
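Note that the membership test works on class names, not classes, which keeps this module free of imports of the actual config classes (the last commit-message bullet). A minimal sketch of the check; the `is_multi_model` helper is hypothetical, the diff inlines the expression at its call sites:

```python
MULTI_MODEL_CONFIGS = {"VisionTextDualEncoderConfig", "SpeechEncoderDecoderConfig"}

def is_multi_model(model_config) -> bool:
    # Hypothetical helper for illustration; comparing the class *name* avoids
    # importing VisionTextDualEncoderConfig / SpeechEncoderDecoderConfig here.
    return model_config.__class__.__name__ in MULTI_MODEL_CONFIGS
```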
@@ -380,8 +385,9 @@ def check_task(task: str) -> Tuple[Dict, Any]:
         - `"zero-shot-image-classification"`

     Returns:
-        (task_defaults`dict`, task_options: (`tuple`, None)) The actual dictionary required to initialize the pipeline
-        and some extra task options for parametrized tasks like "translation_XX_to_YY"
+        (normalized_task: `str`, task_defaults: `dict`, task_options: (`tuple`, None)) The normalized task name
+        (removed alias and options). The actual dictionary required to initialize the pipeline and some extra task
+        options for parametrized tasks like "translation_XX_to_YY"

     """
@@ -614,7 +620,7 @@ def pipeline(
             model, module_file + ".py", class_name, revision=revision, use_auth_token=use_auth_token
         )
     else:
-        targeted_task, task_options = check_task(task)
+        normalized_task, targeted_task, task_options = check_task(task)
         if pipeline_class is None:
             pipeline_class = targeted_task["impl"]
@@ -667,12 +673,36 @@ def pipeline(
     load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
     load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None

+    if (
+        tokenizer is None
+        and not load_tokenizer
+        and normalized_task not in NO_TOKENIZER_TASKS
+        # Using the class name to avoid importing the real class.
+        and model_config.__class__.__name__ in MULTI_MODEL_CONFIGS
+    ):
+        # This is a special category of models that are fusions of multiple models,
+        # so the model_config might not define a tokenizer, but one seems to be
+        # necessary for the task, so we force-try to load it.
+        load_tokenizer = True
+
+    if (
+        feature_extractor is None
+        and not load_feature_extractor
+        and normalized_task not in NO_FEATURE_EXTRACTOR_TASKS
+        # Using the class name to avoid importing the real class.
+        and model_config.__class__.__name__ in MULTI_MODEL_CONFIGS
+    ):
+        # This is a special category of models that are fusions of multiple models,
+        # so the model_config might not define a feature extractor, but one seems
+        # to be necessary for the task, so we force-try to load it.
+        load_feature_extractor = True
+
     if task in NO_TOKENIZER_TASKS:
         # These will never require a tokenizer.
         # The model, on the other hand, might have a tokenizer, but
         # the files could be missing from the hub; instead of failing
         # on such repos, we just force not loading it.
         load_tokenizer = False
     if task in NO_FEATURE_EXTRACTOR_TASKS:
         load_feature_extractor = False
...
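Condensed, the decision rule above for each component (tokenizer or feature extractor) is: never load it for tasks that cannot use it, otherwise load it when the config type is statically mapped or the caller passed one, and force-try the load for the generic multi-model configs. A standalone sketch under those assumptions; the function and its parameter names are illustrative, not part of the diff:

```python
def should_load_component(
    passed_component,      # the tokenizer/feature_extractor argument, possibly None
    statically_mapped,     # type(model_config) found in the relevant *_MAPPING
    normalized_task,       # task name with aliases/options stripped
    no_component_tasks,    # NO_TOKENIZER_TASKS or NO_FEATURE_EXTRACTOR_TASKS
    config_class_name,     # model_config.__class__.__name__
) -> bool:
    if normalized_task in no_component_tasks:
        # These tasks never use the component, even if the repo ships its files.
        return False
    if statically_mapped or passed_component is not None:
        return True
    # Generic multi-model configs: force-try the load.
    return config_class_name in {"VisionTextDualEncoderConfig", "SpeechEncoderDecoderConfig"}
```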
@@ -630,7 +630,6 @@ class PipedPipelineDataFormat(PipelineDataFormat):
         for line in sys.stdin:
             # Split for multi-columns
             if "\t" in line:
-
                 line = line.split("\t")
                 if self.column:
                     # Dictionary to map arguments
@@ -752,7 +751,6 @@ class Pipeline(_ScikitCompat):
         binary_output: bool = False,
         **kwargs,
     ):
-
         if framework is None:
             framework, model = infer_framework_load_model(model, config=model.config)
@@ -1123,18 +1121,19 @@ class PipelineRegistry:
         supported_task.sort()
         return supported_task

-    def check_task(self, task: str) -> Tuple[Dict, Any]:
+    def check_task(self, task: str) -> Tuple[str, Dict, Any]:
         if task in self.task_aliases:
             task = self.task_aliases[task]
         if task in self.supported_tasks:
             targeted_task = self.supported_tasks[task]
-            return targeted_task, None
+            return task, targeted_task, None

         if task.startswith("translation"):
             tokens = task.split("_")
             if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
                 targeted_task = self.supported_tasks["translation"]
-                return targeted_task, (tokens[1], tokens[3])
+                task = "translation"
+                return task, targeted_task, (tokens[1], tokens[3])
             raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")

         raise KeyError(
...
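Callers of `check_task` now unpack three values; the normalized name is what feeds the NO_TOKENIZER_TASKS / NO_FEATURE_EXTRACTOR_TASKS checks above. For example, via the module-level wrapper:

```python
from transformers.pipelines import check_task

# Parametrized translation tasks normalize to "translation" and carry the
# language pair as task_options:
normalized_task, task_defaults, task_options = check_task("translation_en_to_fr")
print(normalized_task)  # "translation"
print(task_options)     # ("en", "fr")

# Aliases normalize too, e.g. "sentiment-analysis" -> "text-classification":
normalized_task, task_defaults, task_options = check_task("sentiment-analysis")
print(normalized_task)  # "text-classification"
print(task_options)     # None
```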
@@ -141,15 +141,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
     @require_torch
     def test_small_model_pt_seq2seq(self):
-        model_id = "hf-internal-testing/tiny-random-speech-encoder-decoder"
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
         speech_recognizer = pipeline(
             task="automatic-speech-recognition",
-            model=model_id,
-            tokenizer=tokenizer,
-            feature_extractor=feature_extractor,
+            model="hf-internal-testing/tiny-random-speech-encoder-decoder",
             framework="pt",
         )
...
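Once built this way, the pipeline is called as usual; a quick smoke run (the silent one-second waveform is made up for illustration, real usage would pass recorded audio or a file path):

```python
import numpy as np

# One second of silence at 16 kHz, just to exercise the pipeline end to end.
waveform = np.zeros(16000, dtype=np.float32)
output = speech_recognizer(waveform)
print(output)  # e.g. {"text": "..."}
```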