"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f5cd27694a0c7d0036954c8350f774a5c1181a57"
Unverified commit 586dcf6b authored by Nicolas Patry, committed by GitHub

Fixing issue where generic model types wouldn't load properly with the pipeline (#18392)

* Adding a better error message when the model is improperly configured within transformers.

* Update src/transformers/pipelines/__init__.py

* Black version.

* Overriding task aliases so that tokenizer+feature_extractor values are correct.

* Fixing task aliases by overriding their names early

* X.

* Fixing feature-extraction.

* black again.

* Normalizing `translation` too.

* Fixing last few corner cases.

  `translation` needs to use its non-normalized name (`translation_XX_to_YY`), so that the `task_specific_params` are correctly overloaded. This can be removed and cleaned up in a later PR (see the standalone normalization sketch after the `check_task` hunk below).

  `speech-encoder-decoder` actually requires a `tokenizer` to be passed manually, so the error needs to be discarded when the `tokenizer` is already there.

* doc-builder fix.

* Fixing the real issue.

* Removing dead code.

* Do not import the actual config classes.
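
With this change, a generic checkpoint loads through `pipeline()` from the model id alone; the tokenizer and feature extractor are force-loaded from the checkpoint repo when the task needs them. A minimal usage sketch (the tiny checkpoint below comes from this PR's own test suite):

from transformers import pipeline

# Previously, SpeechEncoderDecoder checkpoints failed to load in a pipeline
# unless `tokenizer=` and `feature_extractor=` were passed explicitly.
speech_recognizer = pipeline(
    task="automatic-speech-recognition",
    model="hf-internal-testing/tiny-random-speech-encoder-decoder",
    framework="pt",
)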
@@ -311,6 +311,11 @@ SUPPORTED_TASKS = {
 NO_FEATURE_EXTRACTOR_TASKS = set()
 NO_TOKENIZER_TASKS = set()
+# These model configs are special: they are generic over their task, meaning
+# any tokenizer/feature_extractor might be used for a given model, so we cannot
+# use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
+# see if the model defines such objects or not.
+MULTI_MODEL_CONFIGS = {"VisionTextDualEncoderConfig", "SpeechEncoderDecoderConfig"}
 for task, values in SUPPORTED_TASKS.items():
     if values["type"] == "text":
         NO_FEATURE_EXTRACTOR_TASKS.add(task)
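
Why a class-name set instead of the static mappings — a hedged sketch, not part of the diff: generic configs such as `SpeechEncoderDecoderConfig` were (at the time of this PR) absent from the auto-mappings, because no single tokenizer class can be associated with them statically.

from transformers import SpeechEncoderDecoderConfig
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING

# The old gate `type(model_config) in TOKENIZER_MAPPING` always said "no
# tokenizer" for generic configs, even when the checkpoint actually ships
# tokenizer files on the Hub.
print(SpeechEncoderDecoderConfig in TOKENIZER_MAPPING)  # expected: False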
@@ -380,8 +385,9 @@ def check_task(task: str) -> Tuple[Dict, Any]:
             - `"zero-shot-image-classification"`

     Returns:
-        (task_defaults: `dict`, task_options: (`tuple`, None)) The actual dictionary required to initialize the pipeline
-        and some extra task options for parametrized tasks like "translation_XX_to_YY"
+        (normalized_task: `str`, task_defaults: `dict`, task_options: (`tuple`, None)) The normalized task name
+        (removed alias and options). The actual dictionary required to initialize the pipeline and some extra task
+        options for parametrized tasks like "translation_XX_to_YY"
     """
@@ -614,7 +620,7 @@ def pipeline(
             model, module_file + ".py", class_name, revision=revision, use_auth_token=use_auth_token
         )
     else:
-        targeted_task, task_options = check_task(task)
+        normalized_task, targeted_task, task_options = check_task(task)

     if pipeline_class is None:
         pipeline_class = targeted_task["impl"]
@@ -667,12 +673,36 @@ def pipeline(
     load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
     load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None

+    if (
+        tokenizer is None
+        and not load_tokenizer
+        and normalized_task not in NO_TOKENIZER_TASKS
+        # Using class name to avoid importing the real class.
+        and model_config.__class__.__name__ in MULTI_MODEL_CONFIGS
+    ):
+        # This is a special category of models that are fusions of multiple models,
+        # so the model_config might not define a tokenizer, but one seems to be
+        # necessary for the task, so we force-try to load it.
+        load_tokenizer = True
+    if (
+        feature_extractor is None
+        and not load_feature_extractor
+        and normalized_task not in NO_FEATURE_EXTRACTOR_TASKS
+        # Using class name to avoid importing the real class.
+        and model_config.__class__.__name__ in MULTI_MODEL_CONFIGS
+    ):
+        # This is a special category of models that are fusions of multiple models,
+        # so the model_config might not define a feature extractor, but one seems
+        # to be necessary for the task, so we force-try to load it.
+        load_feature_extractor = True

     if task in NO_TOKENIZER_TASKS:
         # These will never require a tokenizer.
         # The model, on the other hand, might have a tokenizer, but
         # the files could be missing from the hub; instead of failing
         # on such repos, we just force not loading it.
         load_tokenizer = False
     if task in NO_FEATURE_EXTRACTOR_TASKS:
         load_feature_extractor = False
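
Condensed, the tokenizer branch above reads as a small predicate. The following is a hypothetical standalone rewrite for illustration (names reused from the diff, the `NO_TOKENIZER_TASKS` subset is invented), not code from the PR:

MULTI_MODEL_CONFIGS = {"VisionTextDualEncoderConfig", "SpeechEncoderDecoderConfig"}
NO_TOKENIZER_TASKS = {"image-classification"}  # illustrative subset only

def should_load_tokenizer(config_class_name: str, task: str, statically_known: bool = False) -> bool:
    load = statically_known  # old behavior: the static mappings decide alone
    if not load and task not in NO_TOKENIZER_TASKS and config_class_name in MULTI_MODEL_CONFIGS:
        load = True  # generic model: force-try loading from the repo
    if task in NO_TOKENIZER_TASKS:
        load = False  # these tasks never need a tokenizer
    return load

assert should_load_tokenizer("SpeechEncoderDecoderConfig", "automatic-speech-recognition")
assert not should_load_tokenizer("SpeechEncoderDecoderConfig", "image-classification")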
@@ -630,7 +630,6 @@ class PipedPipelineDataFormat(PipelineDataFormat):
         for line in sys.stdin:
             # Split for multi-columns
             if "\t" in line:
                 line = line.split("\t")
                 if self.column:
                     # Dictionary to map arguments
@@ -752,7 +751,6 @@ class Pipeline(_ScikitCompat):
         binary_output: bool = False,
         **kwargs,
     ):
         if framework is None:
             framework, model = infer_framework_load_model(model, config=model.config)
@@ -1123,18 +1121,19 @@ class PipelineRegistry:
         supported_task.sort()
         return supported_task

-    def check_task(self, task: str) -> Tuple[Dict, Any]:
+    def check_task(self, task: str) -> Tuple[str, Dict, Any]:
         if task in self.task_aliases:
             task = self.task_aliases[task]
         if task in self.supported_tasks:
             targeted_task = self.supported_tasks[task]
-            return targeted_task, None
+            return task, targeted_task, None

         if task.startswith("translation"):
             tokens = task.split("_")
             if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
                 targeted_task = self.supported_tasks["translation"]
-                return targeted_task, (tokens[1], tokens[3])
+                task = "translation"
+                return task, targeted_task, (tokens[1], tokens[3])
-            raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
+            raise KeyError(
......
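
As referenced in the change list above, here is a minimal standalone sketch of the normalization the updated `check_task` performs for parametrized tasks (illustrative, not the actual `PipelineRegistry` method):

from typing import Optional, Tuple

def normalize_task(task: str) -> Tuple[str, Optional[Tuple[str, str]]]:
    # "translation_en_to_fr" -> ("translation", ("en", "fr"))
    if task.startswith("translation"):
        tokens = task.split("_")
        if len(tokens) == 4 and tokens[0] == "translation" and tokens[2] == "to":
            return "translation", (tokens[1], tokens[3])
        raise KeyError(f"Invalid translation task {task}, use 'translation_XX_to_YY' format")
    return task, None

print(normalize_task("translation_en_to_fr"))  # ('translation', ('en', 'fr'))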
@@ -141,15 +141,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
     @require_torch
     def test_small_model_pt_seq2seq(self):
-        model_id = "hf-internal-testing/tiny-random-speech-encoder-decoder"
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
         speech_recognizer = pipeline(
             task="automatic-speech-recognition",
-            model=model_id,
-            tokenizer=tokenizer,
-            feature_extractor=feature_extractor,
+            model="hf-internal-testing/tiny-random-speech-encoder-decoder",
             framework="pt",
         )
......