Unverified Commit e9688875 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Auto processor fix (#14623)



* Add AutoProcessor class
Init and tests
Add doc
Fix init
Update src/transformers/models/auto/processing_auto.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
Reverts to tokenizer or feature extractor when available
Adapt test

* Revert "Adapt test"

This reverts commit bbdde5fab02465f24b54b227390073082cb32093.

* Revert "Reverts to tokenizer or feature extractor when available"

This reverts commit 77659ff5d21b6cc0baf6f443017e35e056a525bb.

* Don't revert everything Lysandre!
Co-authored-by: default avatarSylvain Gugger <sylvain.gugger@gmail.com>
parent cbe60265
...@@ -28,8 +28,6 @@ from .configuration_auto import ( ...@@ -28,8 +28,6 @@ from .configuration_auto import (
model_type_to_module_name, model_type_to_module_name,
replace_list_option_in_docstrings, replace_list_option_in_docstrings,
) )
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES, AutoFeatureExtractor
from .tokenization_auto import TOKENIZER_MAPPING_NAMES, AutoTokenizer
PROCESSOR_MAPPING_NAMES = OrderedDict( PROCESSOR_MAPPING_NAMES = OrderedDict(
...@@ -85,9 +83,6 @@ class AutoProcessor: ...@@ -85,9 +83,6 @@ class AutoProcessor:
List options List options
For other types of models, this class will return the appropriate tokenizer (if available) or feature
extractor.
Params: Params:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
This can be either: This can be either:
...@@ -167,24 +162,11 @@ class AutoProcessor: ...@@ -167,24 +162,11 @@ class AutoProcessor:
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
model_type = config_class_to_model_type(type(config).__name__) model_type = config_class_to_model_type(type(config).__name__)
if model_type is not None and model_type in PROCESSOR_MAPPING_NAMES: if model_type is not None:
return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs) return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)
# At this stage there doesn't seem to be a `Processor` class available for this model, so let's try a tokenizer
if model_type in TOKENIZER_MAPPING_NAMES:
return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
# At this stage there doesn't seem to be a `Processor` class available for this model, so let's try a tokenizer
if model_type in FEATURE_EXTRACTOR_MAPPING_NAMES:
return AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
all_model_types = set(
PROCESSOR_MAPPING_NAMES.keys() + TOKENIZER_MAPPING_NAMES.keys() + FEATURE_EXTRACTOR_MAPPING_NAMES.keys()
)
all_model_types = list(all_model_types)
all_model_types.sort()
raise ValueError( raise ValueError(
f"Unrecognized processor in {pretrained_model_name_or_path}. Should have a `processor_type` key in " f"Unrecognized processor in {pretrained_model_name_or_path}. Should have a `processor_type` key in "
f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: " f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
f"{', '.join(all_model_types)}" f"{', '.join(c for c in PROCESSOR_MAPPING_NAMES.keys())}"
) )
...@@ -17,8 +17,7 @@ import os ...@@ -17,8 +17,7 @@ import os
import tempfile import tempfile
import unittest import unittest
from transformers import AutoProcessor, BeitFeatureExtractor, BertTokenizerFast, Wav2Vec2Config, Wav2Vec2Processor from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor
from transformers.testing_utils import require_torch
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
...@@ -45,12 +44,3 @@ class AutoFeatureExtractorTest(unittest.TestCase): ...@@ -45,12 +44,3 @@ class AutoFeatureExtractorTest(unittest.TestCase):
processor = AutoProcessor.from_pretrained(tmpdirname) processor = AutoProcessor.from_pretrained(tmpdirname)
self.assertIsInstance(processor, Wav2Vec2Processor) self.assertIsInstance(processor, Wav2Vec2Processor)
def test_auto_processor_reverts_to_tokenizer(self):
processor = AutoProcessor.from_pretrained("bert-base-cased")
self.assertIsInstance(processor, BertTokenizerFast)
@require_torch
def test_auto_processor_reverts_to_feature_extractor(self):
processor = AutoProcessor.from_pretrained("microsoft/beit-base-patch16-224")
self.assertIsInstance(processor, BeitFeatureExtractor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment