Unverified commit ddb1a47e, authored by Sylvain Gugger, committed by GitHub
Browse files

Automatically sort auto mappings (#17250)

* Automatically sort auto mappings

* Better class extraction

* Some auto class magic

* Adapt test and underlying behavior

* Remove re-used config

* Quality
parent 2f611f85
......@@ -857,6 +857,7 @@ jobs:
- run: black --check --preview examples tests src utils
- run: isort --check-only examples tests src utils
- run: python utils/custom_init_isort.py --check_only
- run: python utils/sort_auto_mappings.py --check_only
- run: flake8 examples tests src utils
- run: doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
......
......@@ -48,6 +48,7 @@ quality:
black --check --preview $(check_dirs)
isort --check-only $(check_dirs)
python utils/custom_init_isort.py --check_only
python utils/sort_auto_mappings.py --check_only
flake8 $(check_dirs)
doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
......@@ -55,6 +56,7 @@ quality:
extra_style_checks:
python utils/custom_init_isort.py
python utils/sort_auto_mappings.py
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
# this target runs checks on all files and potentially modifies some of them
......
......@@ -259,7 +259,6 @@ Flax), PyTorch, and/or TensorFlow.
| Swin | ❌ | ❌ | ✅ | ❌ | ❌ |
| T5 | ✅ | ✅ | ✅ | ✅ | ✅ |
| TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
| TAPEX | ✅ | ✅ | ✅ | ✅ | ✅ |
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ |
......
......@@ -74,7 +74,6 @@ Ready-made configurations include the following architectures:
- RoBERTa
- RoFormer
- T5
- TAPEX
- ViT
- XLM-RoBERTa
- XLM-RoBERTa-XL
......
......@@ -560,10 +560,17 @@ class _LazyAutoMapping(OrderedDict):
if key in self._extra_content:
return self._extra_content[key]
model_type = self._reverse_config_mapping[key.__name__]
if model_type not in self._model_mapping:
raise KeyError(key)
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)
if model_type in self._model_mapping:
model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name)
# Maybe there were several model types associated with this config.
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
for mtype in model_types:
if mtype in self._model_mapping:
model_name = self._model_mapping[mtype]
return self._load_attr_from_module(mtype, model_name)
raise KeyError(key)
def _load_attr_from_module(self, model_type, attr):
module_name = model_type_to_module_name(model_type)
......
......@@ -38,30 +38,30 @@ logger = logging.get_logger(__name__)
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
[
("beit", "BeitFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("clip", "CLIPFeatureExtractor"),
("flava", "FlavaFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
("segformer", "SegformerFeatureExtractor"),
("convnext", "ConvNextFeatureExtractor"),
("van", "ConvNextFeatureExtractor"),
("resnet", "ConvNextFeatureExtractor"),
("regnet", "ConvNextFeatureExtractor"),
("poolformer", "PoolFormerFeatureExtractor"),
("maskformer", "MaskFormerFeatureExtractor"),
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
("data2vec-vision", "BeitFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("dpt", "DPTFeatureExtractor"),
("flava", "FlavaFeatureExtractor"),
("glpn", "GLPNFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("maskformer", "MaskFormerFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("poolformer", "PoolFormerFeatureExtractor"),
("regnet", "ConvNextFeatureExtractor"),
("resnet", "ConvNextFeatureExtractor"),
("segformer", "SegformerFeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("van", "ConvNextFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
("yolos", "YolosFeatureExtractor"),
]
)
......@@ -75,8 +75,10 @@ def feature_extractor_class_from_name(class_name: str):
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
break
try:
return getattr(module, class_name)
except AttributeError:
continue
for config, extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.items():
if getattr(extractor, "__name__", None) == class_name:
......
......@@ -28,31 +28,31 @@ logger = logging.get_logger(__name__)
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("xglm", "FlaxXGLMModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("pegasus", "FlaxPegasusModel"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("distilbert", "FlaxDistilBertModel"),
("albert", "FlaxAlbertModel"),
("roberta", "FlaxRobertaModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
("bert", "FlaxBertModel"),
("bart", "FlaxBartModel"),
("beit", "FlaxBeitModel"),
("bert", "FlaxBertModel"),
("big_bird", "FlaxBigBirdModel"),
("bart", "FlaxBartModel"),
("blenderbot", "FlaxBlenderbotModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("clip", "FlaxCLIPModel"),
("distilbert", "FlaxDistilBertModel"),
("electra", "FlaxElectraModel"),
("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"),
("gptj", "FlaxGPTJModel"),
("electra", "FlaxElectraModel"),
("clip", "FlaxCLIPModel"),
("vit", "FlaxViTModel"),
("marian", "FlaxMarianModel"),
("mbart", "FlaxMBartModel"),
("t5", "FlaxT5Model"),
("mt5", "FlaxMT5Model"),
("wav2vec2", "FlaxWav2Vec2Model"),
("marian", "FlaxMarianModel"),
("blenderbot", "FlaxBlenderbotModel"),
("pegasus", "FlaxPegasusModel"),
("roberta", "FlaxRobertaModel"),
("roformer", "FlaxRoFormerModel"),
("t5", "FlaxT5Model"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("vit", "FlaxViTModel"),
("wav2vec2", "FlaxWav2Vec2Model"),
("xglm", "FlaxXGLMModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
]
)
......@@ -60,56 +60,56 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("albert", "FlaxAlbertForPreTraining"),
("roberta", "FlaxRobertaForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
("bert", "FlaxBertForPreTraining"),
("big_bird", "FlaxBigBirdForPreTraining"),
("bart", "FlaxBartForConditionalGeneration"),
("electra", "FlaxElectraForPreTraining"),
("mbart", "FlaxMBartForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("roberta", "FlaxRobertaForMaskedLM"),
("roformer", "FlaxRoFormerForMaskedLM"),
("t5", "FlaxT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
]
)
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("distilbert", "FlaxDistilBertForMaskedLM"),
("albert", "FlaxAlbertForMaskedLM"),
("roberta", "FlaxRobertaForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
("bert", "FlaxBertForMaskedLM"),
("big_bird", "FlaxBigBirdForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
("distilbert", "FlaxDistilBertForMaskedLM"),
("electra", "FlaxElectraForMaskedLM"),
("mbart", "FlaxMBartForConditionalGeneration"),
("roberta", "FlaxRobertaForMaskedLM"),
("roformer", "FlaxRoFormerForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
]
)
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
("pegasus", "FlaxPegasusForConditionalGeneration"),
("bart", "FlaxBartForConditionalGeneration"),
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
("marian", "FlaxMarianMTModel"),
("mbart", "FlaxMBartForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
("pegasus", "FlaxPegasusForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
]
)
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image-classification
("vit", "FlaxViTForImageClassification"),
("beit", "FlaxBeitForImageClassification"),
("vit", "FlaxViTForImageClassification"),
]
)
......@@ -122,75 +122,75 @@ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
]
)
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("distilbert", "FlaxDistilBertForSequenceClassification"),
("albert", "FlaxAlbertForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"),
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
("bart", "FlaxBartForSequenceClassification"),
("bert", "FlaxBertForSequenceClassification"),
("big_bird", "FlaxBigBirdForSequenceClassification"),
("bart", "FlaxBartForSequenceClassification"),
("distilbert", "FlaxDistilBertForSequenceClassification"),
("electra", "FlaxElectraForSequenceClassification"),
("mbart", "FlaxMBartForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"),
("roformer", "FlaxRoFormerForSequenceClassification"),
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
]
)
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("albert", "FlaxAlbertForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"),
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
("bart", "FlaxBartForQuestionAnswering"),
("bert", "FlaxBertForQuestionAnswering"),
("big_bird", "FlaxBigBirdForQuestionAnswering"),
("bart", "FlaxBartForQuestionAnswering"),
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("electra", "FlaxElectraForQuestionAnswering"),
("mbart", "FlaxMBartForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"),
("roformer", "FlaxRoFormerForQuestionAnswering"),
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
]
)
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("distilbert", "FlaxDistilBertForTokenClassification"),
("albert", "FlaxAlbertForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"),
("distilbert", "FlaxDistilBertForTokenClassification"),
("electra", "FlaxElectraForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("roformer", "FlaxRoFormerForTokenClassification"),
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
]
)
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("distilbert", "FlaxDistilBertForMultipleChoice"),
("albert", "FlaxAlbertForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"),
("distilbert", "FlaxDistilBertForMultipleChoice"),
("electra", "FlaxElectraForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("roformer", "FlaxRoFormerForMultipleChoice"),
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
]
)
......
......@@ -41,17 +41,17 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("flava", "FLAVAProcessor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutxlm", "LayoutXLMProcessor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
("speech_to_text", "Speech2TextProcessor"),
("speech_to_text_2", "Speech2Text2Processor"),
("trocr", "TrOCRProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
("unispeech", "Wav2Vec2Processor"),
("unispeech-sat", "Wav2Vec2Processor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
("vilt", "ViltProcessor"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("wavlm", "Wav2Vec2Processor"),
]
)
......@@ -65,7 +65,10 @@ def processor_class_from_name(class_name: str):
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
try:
return getattr(module, class_name)
except AttributeError:
continue
for processor in PROCESSOR_MAPPING._extra_content.values():
if getattr(processor, "__name__", None) == class_name:
......
......@@ -46,34 +46,37 @@ if TYPE_CHECKING:
else:
TOKENIZER_MAPPING_NAMES = OrderedDict(
[
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
(
"t5",
"albert",
(
"T5Tokenizer" if is_sentencepiece_available() else None,
"T5TokenizerFast" if is_tokenizers_available() else None,
"AlbertTokenizer" if is_sentencepiece_available() else None,
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
("bart", ("BartTokenizer", "BartTokenizerFast")),
(
"mt5",
"barthez",
(
"MT5Tokenizer" if is_sentencepiece_available() else None,
"MT5TokenizerFast" if is_tokenizers_available() else None,
"BarthezTokenizer" if is_sentencepiece_available() else None,
"BarthezTokenizerFast" if is_tokenizers_available() else None,
),
),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
("bartpho", ("BartphoTokenizer", None)),
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("bert-japanese", ("BertJapaneseTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),
(
"albert",
"big_bird",
(
"AlbertTokenizer" if is_sentencepiece_available() else None,
"AlbertTokenizerFast" if is_tokenizers_available() else None,
"BigBirdTokenizer" if is_sentencepiece_available() else None,
"BigBirdTokenizerFast" if is_tokenizers_available() else None,
),
),
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
("byt5", ("ByT5Tokenizer", None)),
(
"camembert",
(
......@@ -81,136 +84,119 @@ else:
"CamembertTokenizerFast" if is_tokenizers_available() else None,
),
),
("canine", ("CanineTokenizer", None)),
(
"pegasus",
"clip",
(
"PegasusTokenizer" if is_sentencepiece_available() else None,
"PegasusTokenizerFast" if is_tokenizers_available() else None,
"CLIPTokenizer",
"CLIPTokenizerFast" if is_tokenizers_available() else None,
),
),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"mbart",
"cpm",
(
"MBartTokenizer" if is_sentencepiece_available() else None,
"MBartTokenizerFast" if is_tokenizers_available() else None,
"CpmTokenizer" if is_sentencepiece_available() else None,
"CpmTokenizerFast" if is_tokenizers_available() else None,
),
),
("ctrl", ("CTRLTokenizer", None)),
("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
(
"xlm-roberta",
"deberta-v2",
(
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
"DebertaV2Tokenizer" if is_sentencepiece_available() else None,
"DebertaV2TokenizerFast" if is_tokenizers_available() else None,
),
),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
("tapex", ("TapexTokenizer", None)),
("bart", ("BartTokenizer", "BartTokenizerFast")),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
(
"reformer",
"dpr",
(
"ReformerTokenizer" if is_sentencepiece_available() else None,
"ReformerTokenizerFast" if is_tokenizers_available() else None,
"DPRQuestionEncoderTokenizer",
"DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
),
),
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("flaubert", ("FlaubertTokenizer", None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("fsmt", ("FSMTTokenizer", None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
(
"dpr",
"mbart",
(
"DPRQuestionEncoderTokenizer",
"DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
"MBartTokenizer" if is_sentencepiece_available() else None,
"MBartTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"squeezebert",
("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
),
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("opt", ("GPT2Tokenizer", None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
(
"xlnet",
"mbart50",
(
"XLNetTokenizer" if is_sentencepiece_available() else None,
"XLNetTokenizerFast" if is_tokenizers_available() else None,
"MBart50Tokenizer" if is_sentencepiece_available() else None,
"MBart50TokenizerFast" if is_tokenizers_available() else None,
),
),
("flaubert", ("FlaubertTokenizer", None)),
("xlm", ("XLMTokenizer", None)),
("ctrl", ("CTRLTokenizer", None)),
("fsmt", ("FSMTTokenizer", None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
(
"deberta-v2",
"mt5",
(
"DebertaV2Tokenizer" if is_sentencepiece_available() else None,
"DebertaV2TokenizerFast" if is_tokenizers_available() else None,
"MT5Tokenizer" if is_sentencepiece_available() else None,
"MT5TokenizerFast" if is_tokenizers_available() else None,
),
),
("rag", ("RagTokenizer", None)),
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("tapas", ("TapasTokenizer", None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"big_bird",
"nystromformer",
(
"BigBirdTokenizer" if is_sentencepiece_available() else None,
"BigBirdTokenizerFast" if is_tokenizers_available() else None,
"AlbertTokenizer" if is_sentencepiece_available() else None,
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("canine", ("CanineTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),
("bert-japanese", ("BertJapaneseTokenizer", None)),
("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
("byt5", ("ByT5Tokenizer", None)),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("opt", ("GPT2Tokenizer", None)),
(
"cpm",
"pegasus",
(
"CpmTokenizer" if is_sentencepiece_available() else None,
"CpmTokenizerFast" if is_tokenizers_available() else None,
"PegasusTokenizer" if is_sentencepiece_available() else None,
"PegasusTokenizerFast" if is_tokenizers_available() else None,
),
),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
("bartpho", ("BartphoTokenizer", None)),
(
"barthez",
"perceiver",
(
"BarthezTokenizer" if is_sentencepiece_available() else None,
"BarthezTokenizerFast" if is_tokenizers_available() else None,
"PerceiverTokenizer",
None,
),
),
("phobert", ("PhobertTokenizer", None)),
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("rag", ("RagTokenizer", None)),
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
(
"mbart50",
"reformer",
(
"MBart50Tokenizer" if is_sentencepiece_available() else None,
"MBart50TokenizerFast" if is_tokenizers_available() else None,
"ReformerTokenizer" if is_sentencepiece_available() else None,
"ReformerTokenizerFast" if is_tokenizers_available() else None,
),
),
(
......@@ -220,21 +206,29 @@ else:
"RemBertTokenizerFast" if is_tokenizers_available() else None,
),
),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
(
"clip",
(
"CLIPTokenizer",
"CLIPTokenizerFast" if is_tokenizers_available() else None,
),
"squeezebert",
("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
),
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
(
"perceiver",
"t5",
(
"PerceiverTokenizer",
None,
"T5Tokenizer" if is_sentencepiece_available() else None,
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
("tapas", ("TapasTokenizer", None)),
("tapex", ("TapexTokenizer", None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
(
"xglm",
(
......@@ -242,16 +236,23 @@ else:
"XGLMTokenizerFast" if is_tokenizers_available() else None,
),
),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("xlm", ("XLMTokenizer", None)),
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
(
"nystromformer",
"xlm-roberta",
(
"AlbertTokenizer" if is_sentencepiece_available() else None,
"AlbertTokenizerFast" if is_tokenizers_available() else None,
"XLMRobertaTokenizer" if is_sentencepiece_available() else None,
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
),
),
("xlm-roberta-xl", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
(
"xlnet",
(
"XLNetTokenizer" if is_sentencepiece_available() else None,
"XLNetTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"yoso",
(
......@@ -259,7 +260,6 @@ else:
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
]
)
......@@ -277,7 +277,10 @@ def tokenizer_class_from_name(class_name: str):
module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models")
return getattr(module, class_name)
try:
return getattr(module, class_name)
except AttributeError:
continue
for config, tokenizers in TOKENIZER_MAPPING._extra_content.items():
for tokenizer in tokenizers:
......
......@@ -14,6 +14,8 @@
# limitations under the License.
import importlib
import json
import os
import sys
import tempfile
import unittest
......@@ -56,14 +58,14 @@ class AutoConfigTest(unittest.TestCase):
self.assertIsInstance(config, RobertaConfig)
def test_pattern_matching_fallback(self):
"""
In cases where config.json doesn't include a model_type,
perform a few safety checks on the config mapping's order.
"""
# no key string should be included in a later key string (typical failure case)
keys = list(CONFIG_MAPPING.keys())
for i, key in enumerate(keys):
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
with tempfile.TemporaryDirectory() as tmp_dir:
# This model name contains bert and roberta, but roberta ends up being picked.
folder = os.path.join(tmp_dir, "fake-roberta")
os.makedirs(folder, exist_ok=True)
with open(os.path.join(folder, "config.json"), "w") as f:
f.write(json.dumps({}))
config = AutoConfig.from_pretrained(folder)
self.assertEqual(type(config), RobertaConfig)
def test_new_config_registration(self):
try:
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
PATH_TO_AUTO_MODULE = "src/transformers/models/auto"

# re pattern that matches mapping introductions:
#    SUPER_MODEL_MAPPING_NAMES = OrderedDict or SUPER_MODEL_MAPPING = OrderedDict
# Written as a raw string so that `\s` reaches the regex engine instead of being an invalid str escape.
_re_intro_mapping = re.compile(r"[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict")
# re pattern that matches identifiers in mappings
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')


def sort_auto_mapping(fname, overwrite: bool = False):
    """
    Sort the entries of every auto mapping found in `fname` by their identifier.

    A mapping starts on a line matching `_re_intro_mapping`; its entries are the lines (or multi-line
    parenthesized blocks) located 8 spaces deeper than that line, up to the closing `]`. Lines between entries
    that carry no identifier (comments, blank lines) stay attached to the entry that follows them.

    Args:
        fname: Path to the file to process.
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to write the sorted content back to `fname` or only check it.

    Returns:
        `True` when `overwrite=False` and the file content is not sorted, `None` in every other case.
    """
    with open(fname, "r", encoding="utf-8") as f:
        content = f.read()

    lines = content.split("\n")
    num_lines = len(lines)
    new_lines = []
    line_idx = 0
    while line_idx < num_lines:
        line = lines[line_idx]
        if _re_intro_mapping.search(line) is None:
            new_lines.append(line)
            line_idx += 1
            continue

        # Start of a new mapping! Entries sit 8 spaces deeper than the line introducing it.
        entry_prefix = " " * (len(line) - len(line.lstrip()) + 8)

        # Copy everything before the first entry untouched (bounded so a mapping without entries cannot
        # run past the end of the file).
        while line_idx < num_lines and not lines[line_idx].startswith(entry_prefix + "("):
            new_lines.append(lines[line_idx])
            line_idx += 1

        blocks = []  # (identifier, text) pairs
        pending = []  # lines without identifier, waiting for the entry they precede
        while line_idx < num_lines and lines[line_idx].strip() != "]":
            # Blocks either fit in one line or not
            start_idx = line_idx
            if lines[line_idx].strip() == "(":
                while line_idx < num_lines - 1 and not lines[line_idx].startswith(entry_prefix + ")"):
                    line_idx += 1
            block_lines = lines[start_idx : line_idx + 1]
            line_idx += 1

            match = _re_identifier.search("\n".join(block_lines))
            if match is None:
                # Not an entry: keep it with whichever entry comes next instead of crashing the sort.
                pending.extend(block_lines)
            else:
                blocks.append((match.group(1), "\n".join(pending + block_lines)))
                pending = []

        # Sort blocks by their identifiers (stable, so duplicated identifiers keep their relative order)
        blocks.sort(key=lambda block: block[0])
        new_lines.extend(text for _, text in blocks)
        # Trailing lines without identifier stay right before the closing bracket.
        new_lines.extend(pending)

    new_content = "\n".join(new_lines)
    if new_content == content:
        # Already sorted: nothing to report and nothing to rewrite.
        return None
    if overwrite:
        with open(fname, "w", encoding="utf-8") as f:
            f.write(new_content)
        return None
    return True
def sort_all_auto_mappings(overwrite: bool = False):
    """
    Run `sort_auto_mapping` on every `.py` file located directly inside `PATH_TO_AUTO_MODULE`.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Forwarded to `sort_auto_mapping`: whether to fix the files or only check them.

    Raises:
        ValueError: When `overwrite=False` and at least one file is reported as needing sorting.
    """
    # `os.listdir` returns entries in arbitrary order: sort them so that the processing order and the list of
    # failures in the error message are the same on every run.
    fnames = [os.path.join(PATH_TO_AUTO_MODULE, f) for f in sorted(os.listdir(PATH_TO_AUTO_MODULE)) if f.endswith(".py")]
    diffs = [sort_auto_mapping(fname, overwrite=overwrite) for fname in fnames]

    if not overwrite and any(diffs):
        failures = [f for f, d in zip(fnames, diffs) if d]
        raise ValueError(
            f"The following files have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix"
            " this."
        )
if __name__ == "__main__":
    # Command-line entry point: files are rewritten in place unless `--check_only` is passed.
    cli = argparse.ArgumentParser()
    cli.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    check_only = cli.parse_args().check_only

    sort_all_auto_mappings(not check_only)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment