Unverified Commit c1c2d68d authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix name and get_class method in AutoFeatureExtractor (#13385)

parent a105c9b7
...@@ -26,6 +26,7 @@ from .configuration_auto import ( ...@@ -26,6 +26,7 @@ from .configuration_auto import (
CONFIG_MAPPING_NAMES, CONFIG_MAPPING_NAMES,
AutoConfig, AutoConfig,
config_class_to_model_type, config_class_to_model_type,
model_type_to_module_name,
replace_list_option_in_docstrings, replace_list_option_in_docstrings,
) )
...@@ -40,7 +41,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( ...@@ -40,7 +41,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("wav2vec2", "Wav2Vec2FeatureExtractor"), ("wav2vec2", "Wav2Vec2FeatureExtractor"),
("detr", "DetrFeatureExtractor"), ("detr", "DetrFeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"), ("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("clip", "ClipFeatureExtractor"), ("clip", "CLIPFeatureExtractor"),
] ]
) )
...@@ -50,10 +51,13 @@ FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRA ...@@ -50,10 +51,13 @@ FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRA
def feature_extractor_class_from_name(class_name: str): def feature_extractor_class_from_name(class_name: str):
for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items(): for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
if class_name in extractors: if class_name in extractors:
break module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models") module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name) return getattr(module, class_name)
break
return None
class AutoFeatureExtractor: class AutoFeatureExtractor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment