Unverified Commit ddb1a47e authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Automatically sort auto mappings (#17250)

* Automatically sort auto mappings

* Better class extraction

* Some auto class magic

* Adapt test and underlying behavior

* Remove re-used config

* Quality
parent 2f611f85
...@@ -857,6 +857,7 @@ jobs: ...@@ -857,6 +857,7 @@ jobs:
- run: black --check --preview examples tests src utils - run: black --check --preview examples tests src utils
- run: isort --check-only examples tests src utils - run: isort --check-only examples tests src utils
- run: python utils/custom_init_isort.py --check_only - run: python utils/custom_init_isort.py --check_only
- run: python utils/sort_auto_mappings.py --check_only
- run: flake8 examples tests src utils - run: flake8 examples tests src utils
- run: doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source - run: doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
......
...@@ -48,6 +48,7 @@ quality: ...@@ -48,6 +48,7 @@ quality:
black --check --preview $(check_dirs) black --check --preview $(check_dirs)
isort --check-only $(check_dirs) isort --check-only $(check_dirs)
python utils/custom_init_isort.py --check_only python utils/custom_init_isort.py --check_only
python utils/sort_auto_mappings.py --check_only
flake8 $(check_dirs) flake8 $(check_dirs)
doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
...@@ -55,6 +56,7 @@ quality: ...@@ -55,6 +56,7 @@ quality:
extra_style_checks: extra_style_checks:
python utils/custom_init_isort.py python utils/custom_init_isort.py
python utils/sort_auto_mappings.py
doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source
# this target runs checks on all files and potentially modifies some of them # this target runs checks on all files and potentially modifies some of them
......
...@@ -259,7 +259,6 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -259,7 +259,6 @@ Flax), PyTorch, and/or TensorFlow.
| Swin | ❌ | ❌ | ✅ | ❌ | ❌ | | Swin | ❌ | ❌ | ✅ | ❌ | ❌ |
| T5 | ✅ | ✅ | ✅ | ✅ | ✅ | | T5 | ✅ | ✅ | ✅ | ✅ | ✅ |
| TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ | | TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
| TAPEX | ✅ | ✅ | ✅ | ✅ | ✅ |
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ | | Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ | | TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
| UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ | | UniSpeech | ❌ | ❌ | ✅ | ❌ | ❌ |
......
...@@ -74,7 +74,6 @@ Ready-made configurations include the following architectures: ...@@ -74,7 +74,6 @@ Ready-made configurations include the following architectures:
- RoBERTa - RoBERTa
- RoFormer - RoFormer
- T5 - T5
- TAPEX
- ViT - ViT
- XLM-RoBERTa - XLM-RoBERTa
- XLM-RoBERTa-XL - XLM-RoBERTa-XL
......
...@@ -560,11 +560,18 @@ class _LazyAutoMapping(OrderedDict): ...@@ -560,11 +560,18 @@ class _LazyAutoMapping(OrderedDict):
if key in self._extra_content: if key in self._extra_content:
return self._extra_content[key] return self._extra_content[key]
model_type = self._reverse_config_mapping[key.__name__] model_type = self._reverse_config_mapping[key.__name__]
if model_type not in self._model_mapping: if model_type in self._model_mapping:
raise KeyError(key)
model_name = self._model_mapping[model_type] model_name = self._model_mapping[model_type]
return self._load_attr_from_module(model_type, model_name) return self._load_attr_from_module(model_type, model_name)
# Maybe there was several model types associated with this config.
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
for mtype in model_types:
if mtype in self._model_mapping:
model_name = self._model_mapping[mtype]
return self._load_attr_from_module(mtype, model_name)
raise KeyError(key)
def _load_attr_from_module(self, model_type, attr): def _load_attr_from_module(self, model_type, attr):
module_name = model_type_to_module_name(model_type) module_name = model_type_to_module_name(model_type)
if module_name not in self._modules: if module_name not in self._modules:
......
...@@ -38,30 +38,30 @@ logger = logging.get_logger(__name__) ...@@ -38,30 +38,30 @@ logger = logging.get_logger(__name__)
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
[ [
("beit", "BeitFeatureExtractor"), ("beit", "BeitFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("clip", "CLIPFeatureExtractor"), ("clip", "CLIPFeatureExtractor"),
("flava", "FlavaFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
("segformer", "SegformerFeatureExtractor"),
("convnext", "ConvNextFeatureExtractor"), ("convnext", "ConvNextFeatureExtractor"),
("van", "ConvNextFeatureExtractor"),
("resnet", "ConvNextFeatureExtractor"),
("regnet", "ConvNextFeatureExtractor"),
("poolformer", "PoolFormerFeatureExtractor"),
("maskformer", "MaskFormerFeatureExtractor"),
("data2vec-audio", "Wav2Vec2FeatureExtractor"), ("data2vec-audio", "Wav2Vec2FeatureExtractor"),
("data2vec-vision", "BeitFeatureExtractor"), ("data2vec-vision", "BeitFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("dpt", "DPTFeatureExtractor"), ("dpt", "DPTFeatureExtractor"),
("flava", "FlavaFeatureExtractor"),
("glpn", "GLPNFeatureExtractor"), ("glpn", "GLPNFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
("maskformer", "MaskFormerFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("poolformer", "PoolFormerFeatureExtractor"),
("regnet", "ConvNextFeatureExtractor"),
("resnet", "ConvNextFeatureExtractor"),
("segformer", "SegformerFeatureExtractor"),
("speech_to_text", "Speech2TextFeatureExtractor"),
("swin", "ViTFeatureExtractor"),
("van", "ConvNextFeatureExtractor"),
("vit", "ViTFeatureExtractor"),
("vit_mae", "ViTFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
("yolos", "YolosFeatureExtractor"), ("yolos", "YolosFeatureExtractor"),
] ]
) )
...@@ -75,8 +75,10 @@ def feature_extractor_class_from_name(class_name: str): ...@@ -75,8 +75,10 @@ def feature_extractor_class_from_name(class_name: str):
module_name = model_type_to_module_name(module_name) module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models") module = importlib.import_module(f".{module_name}", "transformers.models")
try:
return getattr(module, class_name) return getattr(module, class_name)
break except AttributeError:
continue
for config, extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.items(): for config, extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.items():
if getattr(extractor, "__name__", None) == class_name: if getattr(extractor, "__name__", None) == class_name:
......
...@@ -28,31 +28,31 @@ logger = logging.get_logger(__name__) ...@@ -28,31 +28,31 @@ logger = logging.get_logger(__name__)
FLAX_MODEL_MAPPING_NAMES = OrderedDict( FLAX_MODEL_MAPPING_NAMES = OrderedDict(
[ [
# Base model mapping # Base model mapping
("xglm", "FlaxXGLMModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("pegasus", "FlaxPegasusModel"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("distilbert", "FlaxDistilBertModel"),
("albert", "FlaxAlbertModel"), ("albert", "FlaxAlbertModel"),
("roberta", "FlaxRobertaModel"), ("bart", "FlaxBartModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
("bert", "FlaxBertModel"),
("beit", "FlaxBeitModel"), ("beit", "FlaxBeitModel"),
("bert", "FlaxBertModel"),
("big_bird", "FlaxBigBirdModel"), ("big_bird", "FlaxBigBirdModel"),
("bart", "FlaxBartModel"), ("blenderbot", "FlaxBlenderbotModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"),
("clip", "FlaxCLIPModel"),
("distilbert", "FlaxDistilBertModel"),
("electra", "FlaxElectraModel"),
("gpt2", "FlaxGPT2Model"), ("gpt2", "FlaxGPT2Model"),
("gpt_neo", "FlaxGPTNeoModel"), ("gpt_neo", "FlaxGPTNeoModel"),
("gptj", "FlaxGPTJModel"), ("gptj", "FlaxGPTJModel"),
("electra", "FlaxElectraModel"), ("marian", "FlaxMarianModel"),
("clip", "FlaxCLIPModel"),
("vit", "FlaxViTModel"),
("mbart", "FlaxMBartModel"), ("mbart", "FlaxMBartModel"),
("t5", "FlaxT5Model"),
("mt5", "FlaxMT5Model"), ("mt5", "FlaxMT5Model"),
("wav2vec2", "FlaxWav2Vec2Model"), ("pegasus", "FlaxPegasusModel"),
("marian", "FlaxMarianModel"), ("roberta", "FlaxRobertaModel"),
("blenderbot", "FlaxBlenderbotModel"),
("roformer", "FlaxRoFormerModel"), ("roformer", "FlaxRoFormerModel"),
("t5", "FlaxT5Model"),
("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
("vit", "FlaxViTModel"),
("wav2vec2", "FlaxWav2Vec2Model"),
("xglm", "FlaxXGLMModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
] ]
) )
...@@ -60,56 +60,56 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ...@@ -60,56 +60,56 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[ [
# Model for pre-training mapping # Model for pre-training mapping
("albert", "FlaxAlbertForPreTraining"), ("albert", "FlaxAlbertForPreTraining"),
("roberta", "FlaxRobertaForMaskedLM"), ("bart", "FlaxBartForConditionalGeneration"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
("bert", "FlaxBertForPreTraining"), ("bert", "FlaxBertForPreTraining"),
("big_bird", "FlaxBigBirdForPreTraining"), ("big_bird", "FlaxBigBirdForPreTraining"),
("bart", "FlaxBartForConditionalGeneration"),
("electra", "FlaxElectraForPreTraining"), ("electra", "FlaxElectraForPreTraining"),
("mbart", "FlaxMBartForConditionalGeneration"), ("mbart", "FlaxMBartForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"), ("mt5", "FlaxMT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"), ("roberta", "FlaxRobertaForMaskedLM"),
("roformer", "FlaxRoFormerForMaskedLM"), ("roformer", "FlaxRoFormerForMaskedLM"),
("t5", "FlaxT5ForConditionalGeneration"),
("wav2vec2", "FlaxWav2Vec2ForPreTraining"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
] ]
) )
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[ [
# Model for Masked LM mapping # Model for Masked LM mapping
("distilbert", "FlaxDistilBertForMaskedLM"),
("albert", "FlaxAlbertForMaskedLM"), ("albert", "FlaxAlbertForMaskedLM"),
("roberta", "FlaxRobertaForMaskedLM"), ("bart", "FlaxBartForConditionalGeneration"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
("bert", "FlaxBertForMaskedLM"), ("bert", "FlaxBertForMaskedLM"),
("big_bird", "FlaxBigBirdForMaskedLM"), ("big_bird", "FlaxBigBirdForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"), ("distilbert", "FlaxDistilBertForMaskedLM"),
("electra", "FlaxElectraForMaskedLM"), ("electra", "FlaxElectraForMaskedLM"),
("mbart", "FlaxMBartForConditionalGeneration"), ("mbart", "FlaxMBartForConditionalGeneration"),
("roberta", "FlaxRobertaForMaskedLM"),
("roformer", "FlaxRoFormerForMaskedLM"), ("roformer", "FlaxRoFormerForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
] ]
) )
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[ [
# Model for Seq2Seq Causal LM mapping # Model for Seq2Seq Causal LM mapping
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
("pegasus", "FlaxPegasusForConditionalGeneration"),
("bart", "FlaxBartForConditionalGeneration"), ("bart", "FlaxBartForConditionalGeneration"),
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
("blenderbot-small", "FlaxBlenderbotSmallForConditionalGeneration"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
("marian", "FlaxMarianMTModel"),
("mbart", "FlaxMBartForConditionalGeneration"), ("mbart", "FlaxMBartForConditionalGeneration"),
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"), ("mt5", "FlaxMT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"), ("pegasus", "FlaxPegasusForConditionalGeneration"),
("encoder-decoder", "FlaxEncoderDecoderModel"), ("t5", "FlaxT5ForConditionalGeneration"),
("blenderbot", "FlaxBlenderbotForConditionalGeneration"),
] ]
) )
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[ [
        # Model for Image-classification # Model for Image-classification
("vit", "FlaxViTForImageClassification"),
("beit", "FlaxBeitForImageClassification"), ("beit", "FlaxBeitForImageClassification"),
("vit", "FlaxViTForImageClassification"),
] ]
) )
...@@ -122,75 +122,75 @@ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( ...@@ -122,75 +122,75 @@ FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[ [
# Model for Causal LM mapping # Model for Causal LM mapping
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
("bart", "FlaxBartForCausalLM"), ("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"), ("bert", "FlaxBertForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"), ("big_bird", "FlaxBigBirdForCausalLM"),
("electra", "FlaxElectraForCausalLM"), ("electra", "FlaxElectraForCausalLM"),
("gpt2", "FlaxGPT2LMHeadModel"),
("gpt_neo", "FlaxGPTNeoForCausalLM"),
("gptj", "FlaxGPTJForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
] ]
) )
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[ [
# Model for Sequence Classification mapping # Model for Sequence Classification mapping
("distilbert", "FlaxDistilBertForSequenceClassification"),
("albert", "FlaxAlbertForSequenceClassification"), ("albert", "FlaxAlbertForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"), ("bart", "FlaxBartForSequenceClassification"),
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
("bert", "FlaxBertForSequenceClassification"), ("bert", "FlaxBertForSequenceClassification"),
("big_bird", "FlaxBigBirdForSequenceClassification"), ("big_bird", "FlaxBigBirdForSequenceClassification"),
("bart", "FlaxBartForSequenceClassification"), ("distilbert", "FlaxDistilBertForSequenceClassification"),
("electra", "FlaxElectraForSequenceClassification"), ("electra", "FlaxElectraForSequenceClassification"),
("mbart", "FlaxMBartForSequenceClassification"), ("mbart", "FlaxMBartForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"),
("roformer", "FlaxRoFormerForSequenceClassification"), ("roformer", "FlaxRoFormerForSequenceClassification"),
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
] ]
) )
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[ [
# Model for Question Answering mapping # Model for Question Answering mapping
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("albert", "FlaxAlbertForQuestionAnswering"), ("albert", "FlaxAlbertForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"), ("bart", "FlaxBartForQuestionAnswering"),
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
("bert", "FlaxBertForQuestionAnswering"), ("bert", "FlaxBertForQuestionAnswering"),
("big_bird", "FlaxBigBirdForQuestionAnswering"), ("big_bird", "FlaxBigBirdForQuestionAnswering"),
("bart", "FlaxBartForQuestionAnswering"), ("distilbert", "FlaxDistilBertForQuestionAnswering"),
("electra", "FlaxElectraForQuestionAnswering"), ("electra", "FlaxElectraForQuestionAnswering"),
("mbart", "FlaxMBartForQuestionAnswering"), ("mbart", "FlaxMBartForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"),
("roformer", "FlaxRoFormerForQuestionAnswering"), ("roformer", "FlaxRoFormerForQuestionAnswering"),
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
] ]
) )
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[ [
# Model for Token Classification mapping # Model for Token Classification mapping
("distilbert", "FlaxDistilBertForTokenClassification"),
("albert", "FlaxAlbertForTokenClassification"), ("albert", "FlaxAlbertForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
("bert", "FlaxBertForTokenClassification"), ("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"), ("big_bird", "FlaxBigBirdForTokenClassification"),
("distilbert", "FlaxDistilBertForTokenClassification"),
("electra", "FlaxElectraForTokenClassification"), ("electra", "FlaxElectraForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("roformer", "FlaxRoFormerForTokenClassification"), ("roformer", "FlaxRoFormerForTokenClassification"),
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
] ]
) )
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[ [
# Model for Multiple Choice mapping # Model for Multiple Choice mapping
("distilbert", "FlaxDistilBertForMultipleChoice"),
("albert", "FlaxAlbertForMultipleChoice"), ("albert", "FlaxAlbertForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
("bert", "FlaxBertForMultipleChoice"), ("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"), ("big_bird", "FlaxBigBirdForMultipleChoice"),
("distilbert", "FlaxDistilBertForMultipleChoice"),
("electra", "FlaxElectraForMultipleChoice"), ("electra", "FlaxElectraForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("roformer", "FlaxRoFormerForMultipleChoice"), ("roformer", "FlaxRoFormerForMultipleChoice"),
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
] ]
) )
......
...@@ -41,17 +41,17 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ...@@ -41,17 +41,17 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("flava", "FLAVAProcessor"), ("flava", "FLAVAProcessor"),
("layoutlmv2", "LayoutLMv2Processor"), ("layoutlmv2", "LayoutLMv2Processor"),
("layoutxlm", "LayoutXLMProcessor"), ("layoutxlm", "LayoutXLMProcessor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
("speech_to_text", "Speech2TextProcessor"), ("speech_to_text", "Speech2TextProcessor"),
("speech_to_text_2", "Speech2Text2Processor"), ("speech_to_text_2", "Speech2Text2Processor"),
("trocr", "TrOCRProcessor"), ("trocr", "TrOCRProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
("unispeech", "Wav2Vec2Processor"), ("unispeech", "Wav2Vec2Processor"),
("unispeech-sat", "Wav2Vec2Processor"), ("unispeech-sat", "Wav2Vec2Processor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
("vilt", "ViltProcessor"), ("vilt", "ViltProcessor"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("wavlm", "Wav2Vec2Processor"), ("wavlm", "Wav2Vec2Processor"),
] ]
) )
...@@ -65,7 +65,10 @@ def processor_class_from_name(class_name: str): ...@@ -65,7 +65,10 @@ def processor_class_from_name(class_name: str):
module_name = model_type_to_module_name(module_name) module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models") module = importlib.import_module(f".{module_name}", "transformers.models")
try:
return getattr(module, class_name) return getattr(module, class_name)
except AttributeError:
continue
for processor in PROCESSOR_MAPPING._extra_content.values(): for processor in PROCESSOR_MAPPING._extra_content.values():
if getattr(processor, "__name__", None) == class_name: if getattr(processor, "__name__", None) == class_name:
......
...@@ -46,34 +46,37 @@ if TYPE_CHECKING: ...@@ -46,34 +46,37 @@ if TYPE_CHECKING:
else: else:
TOKENIZER_MAPPING_NAMES = OrderedDict( TOKENIZER_MAPPING_NAMES = OrderedDict(
[ [
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
( (
"t5", "albert",
( (
"T5Tokenizer" if is_sentencepiece_available() else None, "AlbertTokenizer" if is_sentencepiece_available() else None,
"T5TokenizerFast" if is_tokenizers_available() else None, "AlbertTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("bart", ("BartTokenizer", "BartTokenizerFast")),
( (
"mt5", "barthez",
( (
"MT5Tokenizer" if is_sentencepiece_available() else None, "BarthezTokenizer" if is_sentencepiece_available() else None,
"MT5TokenizerFast" if is_tokenizers_available() else None, "BarthezTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)), ("bartpho", ("BartphoTokenizer", None)),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)), ("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("bert-japanese", ("BertJapaneseTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),
( (
"albert", "big_bird",
( (
"AlbertTokenizer" if is_sentencepiece_available() else None, "BigBirdTokenizer" if is_sentencepiece_available() else None,
"AlbertTokenizerFast" if is_tokenizers_available() else None, "BigBirdTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
("byt5", ("ByT5Tokenizer", None)),
( (
"camembert", "camembert",
( (
...@@ -81,136 +84,119 @@ else: ...@@ -81,136 +84,119 @@ else:
"CamembertTokenizerFast" if is_tokenizers_available() else None, "CamembertTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("canine", ("CanineTokenizer", None)),
( (
"pegasus", "clip",
( (
"PegasusTokenizer" if is_sentencepiece_available() else None, "CLIPTokenizer",
"PegasusTokenizerFast" if is_tokenizers_available() else None, "CLIPTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
( (
"mbart", "cpm",
( (
"MBartTokenizer" if is_sentencepiece_available() else None, "CpmTokenizer" if is_sentencepiece_available() else None,
"MBartTokenizerFast" if is_tokenizers_available() else None, "CpmTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("ctrl", ("CTRLTokenizer", None)),
("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
( (
"xlm-roberta", "deberta-v2",
( (
"XLMRobertaTokenizer" if is_sentencepiece_available() else None, "DebertaV2Tokenizer" if is_sentencepiece_available() else None,
"XLMRobertaTokenizerFast" if is_tokenizers_available() else None, "DebertaV2TokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)), ("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
("blenderbot-small", ("BlenderbotSmallTokenizer", None)),
("blenderbot", ("BlenderbotTokenizer", "BlenderbotTokenizerFast")),
("tapex", ("TapexTokenizer", None)),
("bart", ("BartTokenizer", "BartTokenizerFast")),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
( (
"reformer", "dpr",
( (
"ReformerTokenizer" if is_sentencepiece_available() else None, "DPRQuestionEncoderTokenizer",
"ReformerTokenizerFast" if is_tokenizers_available() else None, "DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)), ("electra", ("ElectraTokenizer", "ElectraTokenizerFast" if is_tokenizers_available() else None)),
("flaubert", ("FlaubertTokenizer", None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("fsmt", ("FSMTTokenizer", None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)), ("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)), ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)), ("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)), ("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)), ("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("longformer", ("LongformerTokenizer", "LongformerTokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
( (
"dpr", "mbart",
( (
"DPRQuestionEncoderTokenizer", "MBartTokenizer" if is_sentencepiece_available() else None,
"DPRQuestionEncoderTokenizerFast" if is_tokenizers_available() else None, "MBartTokenizerFast" if is_tokenizers_available() else None,
),
), ),
(
"squeezebert",
("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
), ),
("bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("opt", ("GPT2Tokenizer", None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
( (
"xlnet", "mbart50",
( (
"XLNetTokenizer" if is_sentencepiece_available() else None, "MBart50Tokenizer" if is_sentencepiece_available() else None,
"XLNetTokenizerFast" if is_tokenizers_available() else None, "MBart50TokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("flaubert", ("FlaubertTokenizer", None)), ("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("xlm", ("XLMTokenizer", None)), ("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
("ctrl", ("CTRLTokenizer", None)), ("mobilebert", ("MobileBertTokenizer", "MobileBertTokenizerFast" if is_tokenizers_available() else None)),
("fsmt", ("FSMTTokenizer", None)), ("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("bert-generation", ("BertGenerationTokenizer" if is_sentencepiece_available() else None, None)),
("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
( (
"deberta-v2", "mt5",
( (
"DebertaV2Tokenizer" if is_sentencepiece_available() else None, "MT5Tokenizer" if is_sentencepiece_available() else None,
"DebertaV2TokenizerFast" if is_tokenizers_available() else None, "MT5TokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("rag", ("RagTokenizer", None)),
("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("mpnet", ("MPNetTokenizer", "MPNetTokenizerFast" if is_tokenizers_available() else None)),
("tapas", ("TapasTokenizer", None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
( (
"big_bird", "nystromformer",
( (
"BigBirdTokenizer" if is_sentencepiece_available() else None, "AlbertTokenizer" if is_sentencepiece_available() else None,
"BigBirdTokenizerFast" if is_tokenizers_available() else None, "AlbertTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("opt", ("GPT2Tokenizer", None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("canine", ("CanineTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),
("bert-japanese", ("BertJapaneseTokenizer", None)),
("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
("byt5", ("ByT5Tokenizer", None)),
( (
"cpm", "pegasus",
( (
"CpmTokenizer" if is_sentencepiece_available() else None, "PegasusTokenizer" if is_sentencepiece_available() else None,
"CpmTokenizerFast" if is_tokenizers_available() else None, "PegasusTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
("bartpho", ("BartphoTokenizer", None)),
( (
"barthez", "perceiver",
( (
"BarthezTokenizer" if is_sentencepiece_available() else None, "PerceiverTokenizer",
"BarthezTokenizerFast" if is_tokenizers_available() else None, None,
), ),
), ),
("phobert", ("PhobertTokenizer", None)),
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("rag", ("RagTokenizer", None)),
("realm", ("RealmTokenizer", "RealmTokenizerFast" if is_tokenizers_available() else None)),
( (
"mbart50", "reformer",
( (
"MBart50Tokenizer" if is_sentencepiece_available() else None, "ReformerTokenizer" if is_sentencepiece_available() else None,
"MBart50TokenizerFast" if is_tokenizers_available() else None, "ReformerTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
( (
...@@ -220,21 +206,29 @@ else: ...@@ -220,21 +206,29 @@ else:
"RemBertTokenizerFast" if is_tokenizers_available() else None, "RemBertTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("retribert", ("RetriBertTokenizer", "RetriBertTokenizerFast" if is_tokenizers_available() else None)),
("roberta", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("splinter", ("SplinterTokenizer", "SplinterTokenizerFast")),
( (
"clip", "squeezebert",
( ("SqueezeBertTokenizer", "SqueezeBertTokenizerFast" if is_tokenizers_available() else None),
"CLIPTokenizer",
"CLIPTokenizerFast" if is_tokenizers_available() else None,
),
), ),
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
( (
"perceiver", "t5",
( (
"PerceiverTokenizer", "T5Tokenizer" if is_sentencepiece_available() else None,
None, "T5TokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("tapas", ("TapasTokenizer", None)),
("tapex", ("TapexTokenizer", None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("wav2vec2", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
( (
"xglm", "xglm",
( (
...@@ -242,16 +236,23 @@ else: ...@@ -242,16 +236,23 @@ else:
"XGLMTokenizerFast" if is_tokenizers_available() else None, "XGLMTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("xlm", ("XLMTokenizer", None)),
("megatron-bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("xlm-prophetnet", ("XLMProphetNetTokenizer" if is_sentencepiece_available() else None, None)),
( (
"nystromformer", "xlm-roberta",
( (
"AlbertTokenizer" if is_sentencepiece_available() else None, "XLMRobertaTokenizer" if is_sentencepiece_available() else None,
"AlbertTokenizerFast" if is_tokenizers_available() else None, "XLMRobertaTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("xlm-roberta-xl", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)), ("xlm-roberta-xl", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
(
"xlnet",
(
"XLNetTokenizer" if is_sentencepiece_available() else None,
"XLNetTokenizerFast" if is_tokenizers_available() else None,
),
),
( (
"yoso", "yoso",
( (
...@@ -259,7 +260,6 @@ else: ...@@ -259,7 +260,6 @@ else:
"AlbertTokenizerFast" if is_tokenizers_available() else None, "AlbertTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
] ]
) )
...@@ -277,7 +277,10 @@ def tokenizer_class_from_name(class_name: str): ...@@ -277,7 +277,10 @@ def tokenizer_class_from_name(class_name: str):
module_name = model_type_to_module_name(module_name) module_name = model_type_to_module_name(module_name)
module = importlib.import_module(f".{module_name}", "transformers.models") module = importlib.import_module(f".{module_name}", "transformers.models")
try:
return getattr(module, class_name) return getattr(module, class_name)
except AttributeError:
continue
for config, tokenizers in TOKENIZER_MAPPING._extra_content.items(): for config, tokenizers in TOKENIZER_MAPPING._extra_content.items():
for tokenizer in tokenizers: for tokenizer in tokenizers:
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
import importlib import importlib
import json
import os
import sys import sys
import tempfile import tempfile
import unittest import unittest
...@@ -56,14 +58,14 @@ class AutoConfigTest(unittest.TestCase): ...@@ -56,14 +58,14 @@ class AutoConfigTest(unittest.TestCase):
self.assertIsInstance(config, RobertaConfig) self.assertIsInstance(config, RobertaConfig)
def test_pattern_matching_fallback(self): def test_pattern_matching_fallback(self):
""" with tempfile.TemporaryDirectory() as tmp_dir:
In cases where config.json doesn't include a model_type, # This model name contains bert and roberta, but roberta ends up being picked.
perform a few safety checks on the config mapping's order. folder = os.path.join(tmp_dir, "fake-roberta")
""" os.makedirs(folder, exist_ok=True)
# no key string should be included in a later key string (typical failure case) with open(os.path.join(folder, "config.json"), "w") as f:
keys = list(CONFIG_MAPPING.keys()) f.write(json.dumps({}))
for i, key in enumerate(keys): config = AutoConfig.from_pretrained(folder)
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :])) self.assertEqual(type(config), RobertaConfig)
def test_new_config_registration(self): def test_new_config_registration(self):
try: try:
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
PATH_TO_AUTO_MODULE = "src/transformers/models/auto"
# re pattern that matches mapping introductions:
# SUPER_MODEL_MAPPING_NAMES = OrderedDict or SUPER_MODEL_MAPPING = OrderedDict
_re_intro_mapping = re.compile("[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict")
# re pattern that matches identifiers in mappings
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
def sort_auto_mapping(fname, overwrite: bool = False):
    """
    Sort all the auto mappings of one file (every `OrderedDict` assigned to a `*_MAPPING*` name) alphabetically by
    their model-type identifier.

    Args:
        fname (`str`): Path of the file in which to sort the mappings.
        overwrite (`bool`, *optional*, defaults to `False`): Whether or not to write the sorted content back to
            `fname`.

    Returns:
        `True` if `overwrite=False` and sorting would change the file, `None` otherwise.
    """
    with open(fname, "r", encoding="utf-8") as f:
        content = f.read()

    lines = content.split("\n")
    new_lines = []
    line_idx = 0
    while line_idx < len(lines):
        if _re_intro_mapping.search(lines[line_idx]) is not None:
            # Entries are indented two levels (8 spaces) below the assignment: one for `OrderedDict(`, one for `[`.
            indent = len(re.search(r"^(\s*)\S", lines[line_idx]).groups()[0]) + 8
            # Start of a new mapping! Copy everything up to the first entry untouched.
            while not lines[line_idx].startswith(" " * indent + "("):
                new_lines.append(lines[line_idx])
                line_idx += 1

            blocks = []
            while lines[line_idx].strip() != "]":
                # Blocks either fit in one line or not
                if lines[line_idx].strip() == "(":
                    start_idx = line_idx
                    while not lines[line_idx].startswith(" " * indent + ")"):
                        line_idx += 1
                    blocks.append("\n".join(lines[start_idx : line_idx + 1]))
                else:
                    blocks.append(lines[line_idx])
                line_idx += 1

            def _block_sort_key(block):
                # Sort by the quoted identifier when there is one; fall back on the raw block text (e.g. for a
                # comment line inside the mapping) instead of crashing on `None.groups()`.
                search = _re_identifier.search(block)
                return search.groups()[0] if search is not None else block

            # Sort blocks by their identifiers
            blocks = sorted(blocks, key=_block_sort_key)
            new_lines += blocks
        else:
            new_lines.append(lines[line_idx])
            line_idx += 1

    if overwrite:
        with open(fname, "w", encoding="utf-8") as f:
            f.write("\n".join(new_lines))
    elif "\n".join(new_lines) != content:
        return True
def sort_all_auto_mappings(overwrite: bool = False):
    """
    Sort every auto mapping in the auto module. With `overwrite=False`, raise a `ValueError` listing the files whose
    mappings are out of order instead of fixing them.
    """
    module_files = []
    for filename in os.listdir(PATH_TO_AUTO_MODULE):
        if filename.endswith(".py"):
            module_files.append(os.path.join(PATH_TO_AUTO_MODULE, filename))
    # Maps each file to whether sorting it would change it (insertion order preserved for the error message).
    diff_by_file = {fname: sort_auto_mapping(fname, overwrite=overwrite) for fname in module_files}

    if not overwrite and any(diff_by_file.values()):
        failures = [fname for fname, has_diff in diff_by_file.items() if has_diff]
        raise ValueError(
            f"The following files have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix"
            " this."
        )
if __name__ == "__main__":
    # Command-line entry point: fixes the mappings in place by default, only reports with `--check_only`.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    parsed_args = arg_parser.parse_args()
    sort_all_auto_mappings(not parsed_args.check_only)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment