Unverified commit 9870093f, authored by Sylvain Gugger, committed by GitHub

[WIP] Disentangle auto modules from other modeling files (#13023)

* Initial work

* All auto models

* All tf auto models

* All flax auto models

* Tokenizers

* Add feature extractors

* Fix typos

* Fix other typo

* Use the right config

* Remove old mapping names and update logic in AutoTokenizer

* Update check_table

* Fix copies and check_repo script

* Fix last test

* Add back name

* clean up

* Update template

* Update template

* Forgot a )

* Use alternative to fixup

* Fix TF model template

* Address review comments

* Address review comments

* Style
parent 2e408236
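
At a high level, this PR replaces auto mappings that held classes with mappings that hold class-name strings keyed by model type, so importing the auto modules no longer drags in every modeling file. A minimal sketch of that pattern, with illustrative names rather than the exact transformers internals:

```python
import importlib
from collections import OrderedDict

# Mapping of model type -> class name as a plain string (no modeling imports needed).
MODEL_MAPPING_NAMES = OrderedDict(
    [
        ("bert", "BertModel"),
        ("gpt2", "GPT2Model"),
    ]
)


def resolve_model_class(model_type):
    # Import the modeling module only when the class is actually requested.
    class_name = MODEL_MAPPING_NAMES[model_type]
    module = importlib.import_module(f".{model_type}", "transformers.models")
    return getattr(module, class_name)
```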
@@ -17,7 +17,7 @@
 {% if cookiecutter.is_encoder_decoder_model == "False" %}
 
 import math
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
@@ -1484,7 +1484,7 @@ from ...file_utils import (
 )
 from ...modeling_tf_outputs import (
     TFBaseModelOutput,
-    TFBaseModelOutputWithPast,
+    TFBaseModelOutputWithPastAndCrossAttentions,
     TFSeq2SeqLMOutput,
     TFSeq2SeqModelOutput,
 )
@@ -2162,7 +2162,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
             )
 
         # encoder layers
-        for encoder_layer in self.layers:
+        for idx, encoder_layer in enumerate(self.layers):
 
             if inputs["output_hidden_states"]:
                 encoder_states = encoder_states + (hidden_states,)
...
@@ -172,17 +172,12 @@
 # To replace in: "src/transformers/models/auto/configuration_auto.py"
 # Below: "# Add configs here"
 # Replace with:
-("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config),
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}Config"),
 # End.
 
 # Below: "# Add archive maps here"
 # Replace with:
-{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP,
-# End.
-
-# Below: "from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
-# Replace with:
-from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP"),
 # End.
 
 # Below: "# Add full (and cased) model names here"
@@ -193,75 +188,47 @@ from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowerca
 # To replace in: "src/transformers/models/auto/modeling_auto.py" if generating PyTorch
 
-# Below: "from .configuration_auto import ("
-# Replace with:
-{{cookiecutter.camelcase_modelname}}Config,
-# End.
-
-# Below: "# Add modeling imports here"
-# Replace with:
-{% if cookiecutter.is_encoder_decoder_model == "False" -%}
-from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
-    {{cookiecutter.camelcase_modelname}}ForMaskedLM,
-    {{cookiecutter.camelcase_modelname}}ForCausalLM,
-    {{cookiecutter.camelcase_modelname}}ForMultipleChoice,
-    {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
-    {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
-    {{cookiecutter.camelcase_modelname}}ForTokenClassification,
-    {{cookiecutter.camelcase_modelname}}Model,
-)
-{% else -%}
-from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
-    {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
-    {{cookiecutter.camelcase_modelname}}ForCausalLM,
-    {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
-    {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
-    {{cookiecutter.camelcase_modelname}}Model,
-)
-{% endif -%}
-# End.
-
 # Below: "# Base model mapping"
 # Replace with:
-({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Model),
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}Model"),
 # End.
 
 # Below: "# Model with LM heads mapping"
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
-({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM),
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
 {% else %}
-({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
 {% endif -%}
 # End.
 
 # Below: "# Model for Causal LM mapping"
 # Replace with:
-({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM),
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForCausalLM"),
 # End.
 
 # Below: "# Model for Masked LM mapping"
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
-({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM),
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
 {% else -%}
 {% endif -%}
 # End.
 
 # Below: "# Model for Sequence Classification mapping"
 # Replace with:
-({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForSequenceClassification),
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForSequenceClassification"),
 # End.
 
 # Below: "# Model for Question Answering mapping"
 # Replace with:
-({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering),
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering"),
 # End.
 
 # Below: "# Model for Token Classification mapping"
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
-({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForTokenClassification),
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForTokenClassification"),
 {% else -%}
 {% endif -%}
 # End.
@@ -269,7 +236,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
 # Below: "# Model for Multiple Choice mapping"
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
-({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMultipleChoice),
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMultipleChoice"),
 {% else -%}
 {% endif -%}
 # End.
@@ -278,54 +245,29 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_mo
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
 {% else %}
-({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
+("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
 {% endif -%}
 # End.
 
 # To replace in: "src/transformers/models/auto/modeling_tf_auto.py" if generating TensorFlow
 
-# Below: "from .configuration_auto import ("
-# Replace with:
-{{cookiecutter.camelcase_modelname}}Config,
-# End.
-
-# Below: "# Add modeling imports here"
-# Replace with:
-{% if cookiecutter.is_encoder_decoder_model == "False" -%}
-from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
-    TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
-    TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
-    TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
-    TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
-    TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
-    TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
-    TF{{cookiecutter.camelcase_modelname}}Model,
-)
-{% else -%}
-from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
-    TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
-    TF{{cookiecutter.camelcase_modelname}}Model,
-)
-{% endif -%}
-# End.
-
 # Below: "# Base model mapping"
 # Replace with:
-({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}Model),
+("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}Model"),
 # End.
 
 # Below: "# Model with LM heads mapping"
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
-({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM),
+("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
 {% else %}
-({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
+("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
 {% endif -%}
 # End.
 
 # Below: "# Model for Causal LM mapping"
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
-({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForCausalLM),
+("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForCausalLM"),
 {% else -%}
 {% endif -%}
 # End.
@@ -333,7 +275,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
 # Below: "# Model for Masked LM mapping"
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
-({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM),
+("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
 {% else -%}
 {% endif -%}
 # End.
@@ -341,7 +283,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
 # Below: "# Model for Sequence Classification mapping"
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
-({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification),
+("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification"),
 {% else -%}
 {% endif -%}
 # End.
@@ -349,7 +291,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
 # Below: "# Model for Question Answering mapping"
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
-({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering),
+("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering"),
 {% else -%}
 {% endif -%}
 # End.
@@ -357,7 +299,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
 # Below: "# Model for Token Classification mapping"
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
-({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForTokenClassification),
+("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForTokenClassification"),
 {% else -%}
 {% endif -%}
 # End.
@@ -365,7 +307,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
 # Below: "# Model for Multiple Choice mapping"
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
-({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice),
+("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice"),
 {% else -%}
 {% endif -%}
 # End.
@@ -374,7 +316,7 @@ from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" -%}
 {% else %}
-({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
+("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
 {% endif -%}
 # End.
...
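
With these template changes, a newly generated model registers itself in the auto mappings as plain string pairs. For a hypothetical model whose lowercase name is `mymodel`, the rendered replacement lines would look roughly like:

```python
# In CONFIG_MAPPING_NAMES (configuration_auto.py):
("mymodel", "MyModelConfig"),

# In the TF base model mapping (modeling_tf_auto.py):
("mymodel", "TFMyModelModel"),
```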
@@ -23,7 +23,8 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
 
 if is_torch_available():
-    from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration
+    from transformers.models.mbart import MBartForConditionalGeneration
+    from transformers.models.mbart50 import MBart50TokenizerFast
 
 
 class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
...
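
The fast MBart-50 tokenizer now lives in its own `mbart50` module, so downstream code imports each class from its new home:

```python
from transformers.models.mbart import MBartForConditionalGeneration
from transformers.models.mbart50 import MBart50TokenizerFast
```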
@@ -306,17 +306,17 @@ def get_all_auto_configured_models():
     result = set()  # To avoid duplicates we concatenate all model classes in a set.
     if is_torch_available():
         for attr_name in dir(transformers.models.auto.modeling_auto):
-            if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
+            if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                 result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
     if is_tf_available():
         for attr_name in dir(transformers.models.auto.modeling_tf_auto):
-            if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
+            if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                 result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
     if is_flax_available():
         for attr_name in dir(transformers.models.auto.modeling_flax_auto):
-            if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
+            if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                 result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
-    return [cls.__name__ for cls in result]
+    return [cls for cls in result]
 
 
 def ignore_unautoclassed(model_name):
...
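
The switch from `*_MAPPING` to `*_MAPPING_NAMES` means `get_values` now yields class-name strings instead of classes, which is why the final line drops `.__name__`. A small illustrative sketch (stand-in class and data, not the real mappings):

```python
class BertModel:  # stand-in for the real class
    pass


# Old style: values are classes, so names must be extracted with __name__.
MODEL_MAPPING = {"bert": BertModel}
# New style: values are already strings.
MODEL_MAPPING_NAMES = {"bert": "BertModel"}

assert [cls.__name__ for cls in MODEL_MAPPING.values()] == list(MODEL_MAPPING_NAMES.values())
```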
@@ -87,12 +87,13 @@ def get_model_table_from_auto_modules():
     transformers = spec.loader.load_module()
 
     # Dictionary model names to config.
+    config_maping_names = transformers.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
     model_name_to_config = {
-        name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
-    }
-    model_name_to_prefix = {
-        name: config.__name__.replace("Config", "") for name, config in model_name_to_config.items()
+        name: config_maping_names[code]
+        for code, name in transformers.MODEL_NAMES_MAPPING.items()
+        if code in config_maping_names
     }
+    model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
 
     # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
     slow_tokenizers = collections.defaultdict(bool)
...
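
The same class-to-string shift shows up here: `model_name_to_config` now maps display names to config-class names (strings), so the prefix is derived with a plain `str.replace`. With illustrative data:

```python
config_maping_names = {"bert": "BertConfig", "gpt2": "GPT2Config"}  # illustrative subset
MODEL_NAMES_MAPPING = {"bert": "BERT", "gpt2": "OpenAI GPT-2"}

model_name_to_config = {
    name: config_maping_names[code]
    for code, name in MODEL_NAMES_MAPPING.items()
    if code in config_maping_names
}
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
# model_name_to_prefix == {"BERT": "Bert", "OpenAI GPT-2": "GPT2"}
```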
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script remaps classes to class strings so that such maps are quick to load and do not
# require importing all possible modeling files.
#
# It can be extended to auto-generate other dicts that are needed at runtime.
import os
import sys
from os.path import abspath, dirname, join


# Make the repo's "src" directory importable so the in-repo transformers is used.
git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
sys.path.insert(1, git_repo_path)

src = "src/transformers/models/auto/modeling_auto.py"
dst = "src/transformers/utils/modeling_auto_mapping.py"

if os.path.exists(dst) and os.path.getmtime(src) < os.path.getmtime(dst):
    # speed things up by only running this script if the src is newer than dst
    sys.exit(0)

# only load if needed
from transformers.models.auto.modeling_auto import (  # noqa
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    MODEL_FOR_MASKED_LM_MAPPING,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
    MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
    MODEL_FOR_OBJECT_DETECTION_MAPPING,
    MODEL_FOR_PRETRAINING_MAPPING,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    MODEL_MAPPING,
    MODEL_WITH_LM_HEAD_MAPPING,
)


# These constants don't have a name attribute, so we need to define it manually
mappings = {
    "MODEL_FOR_QUESTION_ANSWERING_MAPPING": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    "MODEL_FOR_CAUSAL_LM_MAPPING": MODEL_FOR_CAUSAL_LM_MAPPING,
    "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    "MODEL_FOR_MASKED_LM_MAPPING": MODEL_FOR_MASKED_LM_MAPPING,
    "MODEL_FOR_MULTIPLE_CHOICE_MAPPING": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
    "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
    "MODEL_FOR_OBJECT_DETECTION_MAPPING": MODEL_FOR_OBJECT_DETECTION_MAPPING,
    "MODEL_FOR_PRETRAINING_MAPPING": MODEL_FOR_PRETRAINING_MAPPING,
    "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
    "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    "MODEL_MAPPING": MODEL_MAPPING,
    "MODEL_WITH_LM_HEAD_MAPPING": MODEL_WITH_LM_HEAD_MAPPING,
}


def get_name(value):
    # Mapping values can be a class or a tuple of classes: map each to its name.
    if isinstance(value, tuple):
        return tuple(get_name(o) for o in value)
    return value.__name__


content = [
    "# THIS FILE HAS BEEN AUTOGENERATED. To update:",
    "# 1. modify: models/auto/modeling_auto.py",
    "# 2. run: python utils/class_mapping_update.py",
    "from collections import OrderedDict",
    "",
]

for name, mapping in mappings.items():
    entries = "\n".join([f'        ("{k.__name__}", "{get_name(v)}"),' for k, v in mapping.items()])
    content += [
        "",
        f"{name}_NAMES = OrderedDict(",
        "    [",
        entries,
        "    ]",
        ")",
        "",
    ]

print(f"Updating {dst}")
with open(dst, "w", encoding="utf-8", newline="\n") as f:
    f.write("\n".join(content))
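
For reference, the file this script writes is a series of `OrderedDict`s of name pairs keyed by config-class name, following the f-strings above. A generated block would look roughly like this (illustrative entries):

```python
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify: models/auto/modeling_auto.py
# 2. run: python utils/class_mapping_update.py
from collections import OrderedDict


MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("BertConfig", "BertForQuestionAnswering"),
        ("RobertaConfig", "RobertaForQuestionAnswering"),
    ]
)
```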