Unverified Commit c817bc44 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Check all objects are equally in the main `__init__` file (#24573)



* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 8c4471d1
...@@ -142,6 +142,10 @@ The following auto classes are available for the following natural language proc ...@@ -142,6 +142,10 @@ The following auto classes are available for the following natural language proc
[[autodoc]] AutoModelForMaskGeneration [[autodoc]] AutoModelForMaskGeneration
### TFAutoModelForMaskGeneration
[[autodoc]] TFAutoModelForMaskGeneration
### AutoModelForSeq2SeqLM ### AutoModelForSeq2SeqLM
[[autodoc]] AutoModelForSeq2SeqLM [[autodoc]] AutoModelForSeq2SeqLM
...@@ -250,6 +254,10 @@ The following auto classes are available for the following computer vision tasks ...@@ -250,6 +254,10 @@ The following auto classes are available for the following computer vision tasks
[[autodoc]] AutoModelForMaskedImageModeling [[autodoc]] AutoModelForMaskedImageModeling
### TFAutoModelForMaskedImageModeling
[[autodoc]] TFAutoModelForMaskedImageModeling
### AutoModelForObjectDetection ### AutoModelForObjectDetection
[[autodoc]] AutoModelForObjectDetection [[autodoc]] AutoModelForObjectDetection
...@@ -296,6 +304,10 @@ The following auto classes are available for the following audio tasks. ...@@ -296,6 +304,10 @@ The following auto classes are available for the following audio tasks.
### AutoModelForAudioFrameClassification ### AutoModelForAudioFrameClassification
[[autodoc]] TFAutoModelForAudioClassification
### TFAutoModelForAudioFrameClassification
[[autodoc]] AutoModelForAudioFrameClassification [[autodoc]] AutoModelForAudioFrameClassification
### AutoModelForCTC ### AutoModelForCTC
......
...@@ -3025,10 +3025,13 @@ else: ...@@ -3025,10 +3025,13 @@ else:
"TF_MODEL_MAPPING", "TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING", "TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel", "TFAutoModel",
"TFAutoModelForAudioClassification",
"TFAutoModelForCausalLM", "TFAutoModelForCausalLM",
"TFAutoModelForDocumentQuestionAnswering", "TFAutoModelForDocumentQuestionAnswering",
"TFAutoModelForImageClassification", "TFAutoModelForImageClassification",
"TFAutoModelForMaskedImageModeling",
"TFAutoModelForMaskedLM", "TFAutoModelForMaskedLM",
"TFAutoModelForMaskGeneration",
"TFAutoModelForMultipleChoice", "TFAutoModelForMultipleChoice",
"TFAutoModelForNextSentencePrediction", "TFAutoModelForNextSentencePrediction",
"TFAutoModelForPreTraining", "TFAutoModelForPreTraining",
...@@ -6453,10 +6456,13 @@ if TYPE_CHECKING: ...@@ -6453,10 +6456,13 @@ if TYPE_CHECKING:
TF_MODEL_MAPPING, TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel, TFAutoModel,
TFAutoModelForAudioClassification,
TFAutoModelForCausalLM, TFAutoModelForCausalLM,
TFAutoModelForDocumentQuestionAnswering, TFAutoModelForDocumentQuestionAnswering,
TFAutoModelForImageClassification, TFAutoModelForImageClassification,
TFAutoModelForMaskedImageModeling,
TFAutoModelForMaskedLM, TFAutoModelForMaskedLM,
TFAutoModelForMaskGeneration,
TFAutoModelForMultipleChoice, TFAutoModelForMultipleChoice,
TFAutoModelForNextSentencePrediction, TFAutoModelForNextSentencePrediction,
TFAutoModelForPreTraining, TFAutoModelForPreTraining,
......
...@@ -140,9 +140,12 @@ else: ...@@ -140,9 +140,12 @@ else:
"TF_MODEL_MAPPING", "TF_MODEL_MAPPING",
"TF_MODEL_WITH_LM_HEAD_MAPPING", "TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel", "TFAutoModel",
"TFAutoModelForAudioClassification",
"TFAutoModelForCausalLM", "TFAutoModelForCausalLM",
"TFAutoModelForImageClassification", "TFAutoModelForImageClassification",
"TFAutoModelForMaskedImageModeling",
"TFAutoModelForMaskedLM", "TFAutoModelForMaskedLM",
"TFAutoModelForMaskGeneration",
"TFAutoModelForMultipleChoice", "TFAutoModelForMultipleChoice",
"TFAutoModelForNextSentencePrediction", "TFAutoModelForNextSentencePrediction",
"TFAutoModelForPreTraining", "TFAutoModelForPreTraining",
...@@ -313,10 +316,13 @@ if TYPE_CHECKING: ...@@ -313,10 +316,13 @@ if TYPE_CHECKING:
TF_MODEL_MAPPING, TF_MODEL_MAPPING,
TF_MODEL_WITH_LM_HEAD_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel, TFAutoModel,
TFAutoModelForAudioClassification,
TFAutoModelForCausalLM, TFAutoModelForCausalLM,
TFAutoModelForDocumentQuestionAnswering, TFAutoModelForDocumentQuestionAnswering,
TFAutoModelForImageClassification, TFAutoModelForImageClassification,
TFAutoModelForMaskedImageModeling,
TFAutoModelForMaskedLM, TFAutoModelForMaskedLM,
TFAutoModelForMaskGeneration,
TFAutoModelForMultipleChoice, TFAutoModelForMultipleChoice,
TFAutoModelForNextSentencePrediction, TFAutoModelForNextSentencePrediction,
TFAutoModelForPreTraining, TFAutoModelForPreTraining,
......
...@@ -593,7 +593,7 @@ class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass): ...@@ -593,7 +593,7 @@ class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
TF_AutoModelForSemanticSegmentation = auto_class_update( TFAutoModelForSemanticSegmentation = auto_class_update(
TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation" TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
) )
......
...@@ -289,6 +289,13 @@ class TFAutoModel(metaclass=DummyObject): ...@@ -289,6 +289,13 @@ class TFAutoModel(metaclass=DummyObject):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
class TFAutoModelForAudioClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFAutoModelForCausalLM(metaclass=DummyObject): class TFAutoModelForCausalLM(metaclass=DummyObject):
_backends = ["tf"] _backends = ["tf"]
...@@ -310,6 +317,13 @@ class TFAutoModelForImageClassification(metaclass=DummyObject): ...@@ -310,6 +317,13 @@ class TFAutoModelForImageClassification(metaclass=DummyObject):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
class TFAutoModelForMaskedImageModeling(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFAutoModelForMaskedLM(metaclass=DummyObject): class TFAutoModelForMaskedLM(metaclass=DummyObject):
_backends = ["tf"] _backends = ["tf"]
...@@ -317,6 +331,13 @@ class TFAutoModelForMaskedLM(metaclass=DummyObject): ...@@ -317,6 +331,13 @@ class TFAutoModelForMaskedLM(metaclass=DummyObject):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
class TFAutoModelForMaskGeneration(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFAutoModelForMultipleChoice(metaclass=DummyObject): class TFAutoModelForMultipleChoice(metaclass=DummyObject):
_backends = ["tf"] _backends = ["tf"]
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import inspect import inspect
import os import os
import re import re
import sys
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from difflib import get_close_matches from difflib import get_close_matches
...@@ -336,6 +337,21 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ ...@@ -336,6 +337,21 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"MusicgenForConditionalGeneration", "MusicgenForConditionalGeneration",
] ]
# DO NOT edit this list!
# (The corresponding pytorch objects should never be in the main `__init__`, but it's too late to remove)
OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
"FlaxBertLayer",
"FlaxBigBirdLayer",
"FlaxRoFormerLayer",
"TFBertLayer",
"TFLxmertEncoder",
"TFLxmertXLayer",
"TFMPNetLayer",
"TFMobileBertLayer",
"TFSegformerLayer",
"TFViTMAELayer",
]
# Update this list for models that have multiple model types for the same # Update this list for models that have multiple model types for the same
# model doc # model doc
MODEL_TYPE_TO_DOC_MAPPING = OrderedDict( MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
...@@ -735,7 +751,48 @@ def check_all_auto_mappings_importable(): ...@@ -735,7 +751,48 @@ def check_all_auto_mappings_importable():
for name, _ in mappings_to_check.items(): for name, _ in mappings_to_check.items():
name = name.replace("_MAPPING_NAMES", "_MAPPING") name = name.replace("_MAPPING_NAMES", "_MAPPING")
if not hasattr(transformers, name): if not hasattr(transformers, name):
failures.append(f"`{name}` should be defined in the main `__init__` file.") failures.append(f"`{name}`")
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def check_objects_being_equally_in_main_init():
"""Check if an object is in the main __init__ if its counterpart in PyTorch is."""
attrs = dir(transformers)
failures = []
for attr in attrs:
obj = getattr(transformers, attr)
if hasattr(obj, "__module__"):
module_path = obj.__module__
module_name = module_path.split(".")[-1]
module_dir = ".".join(module_path.split(".")[:-1])
if (
module_name.startswith("modeling_")
and not module_name.startswith("modeling_tf_")
and not module_name.startswith("modeling_flax_")
):
parent_module = sys.modules[module_dir]
frameworks = []
if is_tf_available():
frameworks.append("TF")
if is_flax_available():
frameworks.append("Flax")
for framework in frameworks:
other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
other_module = getattr(parent_module, other_module_name)
if hasattr(other_module, f"{framework}{attr}"):
if not hasattr(transformers, f"{framework}{attr}"):
if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}{attr}")
if hasattr(other_module, f"{framework}_{attr}"):
if not hasattr(transformers, f"{framework}_{attr}"):
if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}_{attr}")
if len(failures) > 0: if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures)) raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
...@@ -1024,6 +1081,8 @@ def check_repo_quality(): ...@@ -1024,6 +1081,8 @@ def check_repo_quality():
check_all_auto_mapping_names_in_config_mapping_names() check_all_auto_mapping_names_in_config_mapping_names()
print("Checking all auto mappings could be imported.") print("Checking all auto mappings could be imported.")
check_all_auto_mappings_importable() check_all_auto_mappings_importable()
print("Checking all objects are equally (across frameworks) in the main __init__.")
check_objects_being_equally_in_main_init()
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment