Unverified Commit 9870093f authored by Sylvain Gugger, committed by GitHub

[WIP] Disentangle auto modules from other modeling files (#13023)

* Initial work

* All auto models

* All tf auto models

* All flax auto models

* Tokenizers

* Add feature extractors

* Fix typos

* Fix other typo

* Use the right config

* Remove old mapping names and update logic in AutoTokenizer

* Update check_table

* Fix copies and check_repo script

* Fix last test

* Add back name

* clean up

* Update template

* Update template

* Forgot a )

* Use alternative to fixup

* Fix TF model template

* Address review comments

* Address review comments

* Style
parent 2e408236
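
The change running through every file below: the auto mappings no longer store imported classes keyed by config class, but class-name strings keyed by model type, resolved to real classes only on first access. A minimal sketch of that lazy pattern, with illustrative names (the actual implementation in the library differs):

```python
import importlib
from collections import OrderedDict

# New style: model-type string -> class-name string (no imports needed up front).
MODEL_MAPPING_NAMES = OrderedDict([("bert", "BertModel")])


class LazyMapping:
    """Resolves a class-name string to the real class on first access (sketch only)."""

    def __init__(self, names):
        self._names = names
        self._cache = {}

    def __getitem__(self, model_type):
        if model_type not in self._cache:
            # Import the model's submodule only when its entry is actually read.
            module = importlib.import_module(f".{model_type}", "transformers.models")
            self._cache[model_type] = getattr(module, self._names[model_type])
        return self._cache[model_type]


MODEL_MAPPING = LazyMapping(MODEL_MAPPING_NAMES)  # MODEL_MAPPING["bert"] imports lazily
```

Because nothing is imported until a mapping entry is actually read, importing the auto module no longer pulls in every modeling file.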
@@ -17,7 +17,7 @@
{% if cookiecutter.is_encoder_decoder_model == "False" %}
import math
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
@@ -1484,7 +1484,7 @@ from ...file_utils import (
)
from ...modeling_tf_outputs import (
    TFBaseModelOutput,
-    TFBaseModelOutputWithPast,
+    TFBaseModelOutputWithPastAndCrossAttentions,
    TFSeq2SeqLMOutput,
    TFSeq2SeqModelOutput,
)
@@ -2162,7 +2162,7 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
        )
        # encoder layers
-        for encoder_layer in self.layers:
+        for idx, encoder_layer in enumerate(self.layers):
            if inputs["output_hidden_states"]:
                encoder_states = encoder_states + (hidden_states,)
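For context on the loop change above: tracking the index with `enumerate` is the usual way to pick out per-layer inputs such as a head mask. A standalone sketch of the pattern (illustrative, not the template code):

```python
from typing import Callable, List, Optional


def run_encoder_layers(layers: List[Callable], hidden_states, head_mask: Optional[List] = None):
    # The enumerate index selects the mask entry for each layer in turn.
    for idx, encoder_layer in enumerate(layers):
        layer_head_mask = head_mask[idx] if head_mask is not None else None
        hidden_states = encoder_layer(hidden_states, layer_head_mask)
    return hidden_states
```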
@@ -172,17 +172,12 @@
# To replace in: "src/transformers/models/auto/configuration_auto.py"
# Below: "# Add configs here"
# Replace with:
("{{cookiecutter.lowercase_modelname}}", {{cookiecutter.camelcase_modelname}}Config),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}Config"),
# End.
# Below: "# Add archive maps here"
# Replace with:
{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP,
# End.
# Below: "from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig",
# Replace with:
from ..{{cookiecutter.lowercase_modelname}}.configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP"),
# End.
# Below: "# Add full (and cased) model names here"
@@ -193,75 +188,47 @@
# To replace in: "src/transformers/models/auto/modeling_auto.py" if generating PyTorch
# Below: "from .configuration_auto import ("
# Replace with:
{{cookiecutter.camelcase_modelname}}Config,
# End.
# Below: "# Add modeling imports here"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
{{cookiecutter.camelcase_modelname}}Model,
)
{% else -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model,
)
{% endif -%}
# End.
# Below: "# Base model mapping"
# Replace with:
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Model),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}Model"),
# End.
# Below: "# Model with LM heads mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else %}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.
# Below: "# Model for Causal LM mapping"
# Replace with:
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForCausalLM),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForCausalLM"),
# End.
# Below: "# Model for Masked LM mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMaskedLM),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else -%}
{% endif -%}
# End.
# Below: "# Model for Sequence Classification mapping"
# Replace with:
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForSequenceClassification),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForSequenceClassification"),
# End.
# Below: "# Model for Question Answering mapping"
# Replace with:
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering"),
# End.
# Below: "# Model for Token Classification mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForTokenClassification),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForTokenClassification"),
{% else -%}
{% endif -%}
# End.
@@ -269,7 +236,7 @@
# Below: "# Model for Multiple Choice mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForMultipleChoice),
("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForMultipleChoice"),
{% else -%}
{% endif -%}
# End.
@@ -278,54 +245,29 @@
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
{% else %}
-    ({{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
+    ("{{cookiecutter.lowercase_modelname}}", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.
# To replace in: "src/transformers/models/auto/modeling_tf_auto.py" if generating TensorFlow
# Below: "from .configuration_auto import ("
# Replace with:
{{cookiecutter.camelcase_modelname}}Config,
# End.
# Below: "# Add modeling imports here"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
TF{{cookiecutter.camelcase_modelname}}Model,
)
{% else -%}
from ..{{cookiecutter.lowercase_modelname}}.modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
TF{{cookiecutter.camelcase_modelname}}Model,
)
{% endif -%}
# End.
# Below: "# Base model mapping"
# Replace with:
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}Model),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}Model"),
# End.
# Below: "# Model with LM heads mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else %}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.
# Below: "# Model for Causal LM mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForCausalLM),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForCausalLM"),
{% else -%}
{% endif -%}
# End.
@@ -333,7 +275,7 @@
# Below: "# Model for Masked LM mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM"),
{% else -%}
{% endif -%}
# End.
@@ -341,7 +283,7 @@
# Below: "# Model for Sequence Classification mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification"),
{% else -%}
{% endif -%}
# End.
@@ -349,7 +291,7 @@
# Below: "# Model for Question Answering mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering"),
{% else -%}
{% endif -%}
# End.
@@ -357,7 +299,7 @@
# Below: "# Model for Token Classification mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForTokenClassification),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForTokenClassification"),
{% else -%}
{% endif -%}
# End.
@@ -365,7 +307,7 @@
# Below: "# Model for Multiple Choice mapping"
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice),
("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice"),
{% else -%}
{% endif -%}
# End.
@@ -374,7 +316,7 @@
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" -%}
{% else %}
-    ({{cookiecutter.camelcase_modelname}}Config, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration),
+    ("{{cookiecutter.lowercase_modelname}}", "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration"),
{% endif -%}
# End.
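The `# To replace in:` / `# Below:` / `# Replace with:` / `# End.` lines above form a small directive format consumed by the add-new-model tooling: each block names an anchor line in a target file and the lines to insert there. A simplified sketch of a parser for that format (not the actual implementation):

```python
def parse_directives(lines):
    """Yield (anchor, replacement_lines) pairs from a to_replace-style file."""
    anchor, block, collecting = None, [], False
    for line in lines:
        if line.startswith('# Below: '):
            # The anchor is the quoted line to search for in the target file.
            anchor = line[len('# Below: '):].strip().strip('",')
        elif line.startswith('# Replace with:'):
            block, collecting = [], True
        elif line.startswith('# End.'):
            if anchor is not None:
                yield anchor, block
            anchor, block, collecting = None, [], False
        elif collecting:
            block.append(line)
```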
@@ -23,7 +23,8 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
if is_torch_available():
-    from transformers.models.mbart import MBart50TokenizerFast, MBartForConditionalGeneration
+    from transformers.models.mbart import MBartForConditionalGeneration
+    from transformers.models.mbart50 import MBart50TokenizerFast

class TranslationEnToDePipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
@@ -306,17 +306,17 @@ def get_all_auto_configured_models():
    result = set()  # To avoid duplicates we concatenate all model classes in a set.
    if is_torch_available():
        for attr_name in dir(transformers.models.auto.modeling_auto):
-            if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
+            if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
    if is_tf_available():
        for attr_name in dir(transformers.models.auto.modeling_tf_auto):
-            if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
+            if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
    if is_flax_available():
        for attr_name in dir(transformers.models.auto.modeling_flax_auto):
-            if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
+            if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
                result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
-    return [cls.__name__ for cls in result]
+    return [cls for cls in result]
def ignore_unautoclassed(model_name):
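The two changes above go together: the `*_MAPPING_NAMES` dictionaries hold class-name strings (occasionally tuples of them), so the collected values are already names and calling `cls.__name__` on a string would fail. A small illustration with toy data, where `get_values` is a stand-in for the helper the script uses:

```python
# Toy stand-in: values in *_MAPPING_NAMES dicts are class-name strings.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = {"bert": "BertForMaskedLM", "roberta": "RobertaForMaskedLM"}


def get_values(mapping):
    """Flatten mapping values, expanding tuples (stand-in for the real helper)."""
    values = []
    for value in mapping.values():
        if isinstance(value, (list, tuple)):
            values.extend(value)
        else:
            values.append(value)
    return values


result = set(get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES))
print([cls for cls in result])  # already names, e.g. ['BertForMaskedLM', 'RobertaForMaskedLM']
```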
@@ -87,12 +87,13 @@ def get_model_table_from_auto_modules():
    transformers = spec.loader.load_module()

    # Dictionary mapping model names to config names.
+    config_mapping_names = transformers.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
    model_name_to_config = {
-        name: transformers.CONFIG_MAPPING[code] for code, name in transformers.MODEL_NAMES_MAPPING.items()
+        name: config_mapping_names[code]
+        for code, name in transformers.MODEL_NAMES_MAPPING.items()
+        if code in config_mapping_names
    }
-    model_name_to_prefix = {
-        name: config.__name__.replace("Config", "") for name, config in model_name_to_config.items()
-    }
+    model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}

    # Dictionaries flagging if each model prefix has a slow/fast tokenizer, backend in PT/TF/Flax.
    slow_tokenizers = collections.defaultdict(bool)
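To make the new lookup concrete, here is the same derivation on a toy entry; `CONFIG_MAPPING_NAMES` maps a model type to a config class name, so the prefix falls out of a string replace (toy data, not the real tables):

```python
# Toy versions of the two tables the script reads.
CONFIG_MAPPING_NAMES = {"bert": "BertConfig"}  # model type -> config class name
MODEL_NAMES_MAPPING = {"bert": "BERT"}         # model type -> display name

model_name_to_config = {
    name: CONFIG_MAPPING_NAMES[code]
    for code, name in MODEL_NAMES_MAPPING.items()
    if code in CONFIG_MAPPING_NAMES
}
model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
assert model_name_to_prefix == {"BERT": "Bert"}
```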
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script remaps classes to class-name strings so that such maps can be loaded quickly,
# without importing every possible modeling file.
#
# It can be extended to auto-generate other dicts that are needed at runtime.
import os
import sys
from os.path import abspath, dirname, join

git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
sys.path.insert(1, git_repo_path)

src = "src/transformers/models/auto/modeling_auto.py"
dst = "src/transformers/utils/modeling_auto_mapping.py"

if os.path.exists(dst) and os.path.getmtime(src) < os.path.getmtime(dst):
    # Speed things up by only running this script if src is newer than dst.
    sys.exit(0)
# only load if needed
from transformers.models.auto.modeling_auto import ( # noqa
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
MODEL_FOR_OBJECT_DETECTION_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
MODEL_WITH_LM_HEAD_MAPPING,
)
# These constants have no __name__ attribute, so their names must be spelled out manually.
mappings = {
    "MODEL_FOR_CAUSAL_LM_MAPPING": MODEL_FOR_CAUSAL_LM_MAPPING,
    "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    "MODEL_FOR_MASKED_LM_MAPPING": MODEL_FOR_MASKED_LM_MAPPING,
    "MODEL_FOR_MULTIPLE_CHOICE_MAPPING": MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
    "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
    "MODEL_FOR_OBJECT_DETECTION_MAPPING": MODEL_FOR_OBJECT_DETECTION_MAPPING,
    "MODEL_FOR_PRETRAINING_MAPPING": MODEL_FOR_PRETRAINING_MAPPING,
    "MODEL_FOR_QUESTION_ANSWERING_MAPPING": MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
    "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
    "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    "MODEL_MAPPING": MODEL_MAPPING,
    "MODEL_WITH_LM_HEAD_MAPPING": MODEL_WITH_LM_HEAD_MAPPING,
}
def get_name(value):
    # Mapping values may be tuples of classes; recurse to get each class name.
    if isinstance(value, tuple):
        return tuple(get_name(o) for o in value)
    return value.__name__
content = [
    "# THIS FILE HAS BEEN AUTOGENERATED. To update:",
    "# 1. modify: models/auto/modeling_auto.py",
    "# 2. run: python utils/class_mapping_update.py",
    "from collections import OrderedDict",
    "",
]

# Render each mapping as an OrderedDict of ("ConfigName", "ModelName") string pairs.
for name, mapping in mappings.items():
    entries = "\n".join([f'        ("{k.__name__}", "{get_name(v)}"),' for k, v in mapping.items()])
    content += [
        "",
        f"{name}_NAMES = OrderedDict(",
        "    [",
        entries,
        "    ]",
        ")",
        "",
    ]

print(f"Updating {dst}")
with open(dst, "w", encoding="utf-8", newline="\n") as f:
    f.write("\n".join(content))
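
For reference, running `python utils/class_mapping_update.py` from the repository root regenerates the target file (or exits immediately if it is already up to date). The generated content would look roughly like this (entries illustrative, not the full list):

```python
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify: models/auto/modeling_auto.py
# 2. run: python utils/class_mapping_update.py
from collections import OrderedDict

MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        ("BertConfig", "BertForQuestionAnswering"),
        ("RobertaConfig", "RobertaForQuestionAnswering"),
    ]
)
```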