Unverified Commit c1c2d68d authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix name and get_class method in AutoFeatureExtractor (#13385)

parent a105c9b7
......@@ -26,6 +26,7 @@ from .configuration_auto import (
CONFIG_MAPPING_NAMES,
AutoConfig,
config_class_to_model_type,
model_type_to_module_name,
replace_list_option_in_docstrings,
)
......@@ -40,7 +41,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("wav2vec2", "Wav2Vec2FeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("clip", "ClipFeatureExtractor"),
("clip", "CLIPFeatureExtractor"),
]
)
......@@ -50,10 +51,13 @@ FEATURE_EXTRACTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FEATURE_EXTRA
def feature_extractor_class_from_name(class_name: str):
for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
if class_name in extractors:
break
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
break
return None
class AutoFeatureExtractor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment