Unverified Commit 99ba36e7 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Clean up auto mapping names (#21903)



* add new test

* fix after new test

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 50a8ed3e
......@@ -43,7 +43,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("blenderbot", "BlenderbotModel"),
("blenderbot-small", "BlenderbotSmallModel"),
("blip", "BlipModel"),
("blip_2", "Blip2Model"),
("blip-2", "Blip2Model"),
("bloom", "BloomModel"),
("bridgetower", "BridgeTowerModel"),
("camembert", "CamembertModel"),
......@@ -64,7 +64,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
("deberta", "DebertaModel"),
("deberta-v2", "DebertaV2Model"),
("decision_transformer", "DecisionTransformerModel"),
("decision_transformer_gpt2", "DecisionTransformerGPT2Model"),
("deformable_detr", "DeformableDetrModel"),
("deit", "DeiTModel"),
("deta", "DetaModel"),
......@@ -128,7 +127,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
("mvp", "MvpModel"),
("nat", "NatModel"),
("nezha", "NezhaModel"),
("nllb", "M2M100Model"),
("nystromformer", "NystromformerModel"),
("oneformer", "OneFormerModel"),
("openai-gpt", "OpenAIGPTModel"),
......
......@@ -56,7 +56,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("hubert", "Wav2Vec2Processor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutlmv3", "LayoutLMv3Processor"),
("layoutxlm", "LayoutXLMProcessor"),
("markuplm", "MarkupLMProcessor"),
("oneformer", "OneFormerProcessor"),
("owlvit", "OwlViTProcessor"),
......@@ -72,7 +71,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2-conformer", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("wavlm", "Wav2Vec2Processor"),
("whisper", "WhisperProcessor"),
("xclip", "XCLIPProcessor"),
......
......@@ -23,6 +23,7 @@ from pathlib import Path
from transformers import is_flax_available, is_tf_available, is_torch_available
from transformers.models.auto import get_values
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
from transformers.models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
......@@ -646,6 +647,31 @@ def check_all_auto_object_names_being_defined():
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def check_all_auto_mapping_names_in_config_mapping_names():
"""Check all keys defined in auto mappings (mappings of names) appear in `CONFIG_MAPPING_NAMES`."""
failures = []
# `TOKENIZER_PROCESSOR_MAPPING_NAMES` and `AutoTokenizer` is special, and don't need to follow the rule.
mapping_to_check = {
"IMAGE_PROCESSOR_MAPPING_NAMES": IMAGE_PROCESSOR_MAPPING_NAMES,
"FEATURE_EXTRACTOR_MAPPING_NAMES": FEATURE_EXTRACTOR_MAPPING_NAMES,
"PROCESSOR_MAPPING_NAMES": PROCESSOR_MAPPING_NAMES,
"MODEL_MAPPING_NAMES": MODEL_MAPPING_NAMES,
"TF_MODEL_MAPPING_NAMES": TF_MODEL_MAPPING_NAMES,
"FLAX_MODEL_MAPPING_NAMES": FLAX_MODEL_MAPPING_NAMES,
}
for name, mapping in mapping_to_check.items():
for model_type, class_names in mapping.items():
if model_type not in CONFIG_MAPPING_NAMES:
failures.append(
f"`{model_type}` appears in the mapping `{name}` but it is not defined in the keys of "
"`CONFIG_MAPPING_NAMES`."
)
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
......@@ -922,6 +948,8 @@ def check_repo_quality():
check_all_models_are_auto_configured()
print("Checking all names in auto name mappings are defined.")
check_all_auto_object_names_being_defined()
print("Checking all keys in auto name mappings are defined in `CONFIG_MAPPING_NAMES`.")
check_all_auto_mapping_names_in_config_mapping_names()
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment