Unverified Commit 7f3b41a3 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix check repo utils (#8600)

parent f0435f5a
......@@ -49,6 +49,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
"test_modeling_mt5.py",
"test_modeling_pegasus.py",
"test_modeling_tf_camembert.py",
"test_modeling_tf_mt5.py",
"test_modeling_tf_xlm_roberta.py",
"test_modeling_xlm_prophetnet.py",
"test_modeling_xlm_roberta.py",
......@@ -62,7 +63,6 @@ IGNORE_NON_DOCUMENTED = [
"T5Stack", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (documented) model.
"TFDPRSpanPredictor", # Building part of bigger (documented) model.
"TFElectraMainLayer", # Building part of bigger (documented) model (should it be a TFPreTrainedModel ?)
]
# Update this dict with any special correspondance model name (used in modeling_xxx.py) to doc file.
......@@ -135,11 +135,15 @@ def get_model_modules():
"modeling_tf_transfo_xl_utilities",
]
modules = []
for attr_name in dir(transformers):
if attr_name.startswith("modeling") and attr_name not in _ignore_modules:
module = getattr(transformers, attr_name)
if inspect.ismodule(module):
modules.append(module)
for model in dir(transformers.models):
# There are some magic dunder attributes in the dir, we ignore them
if not model.startswith("__"):
model_module = getattr(transformers.models, model)
for submodule in dir(model_module):
if submodule.startswith("modeling") and submodule not in _ignore_modules:
modeling_module = getattr(model_module, submodule)
if inspect.ismodule(modeling_module):
modules.append(modeling_module)
return modules
......@@ -244,7 +248,7 @@ def check_all_models_are_tested():
test_files = get_model_test_files()
failures = []
for module in modules:
test_file = f"test_{module.__name__.split('.')[1]}.py"
test_file = f"test_{module.__name__.split('.')[-1]}.py"
if test_file not in test_files:
failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.")
new_failures = check_models_are_tested(module, test_file)
......@@ -279,9 +283,9 @@ def check_models_are_documented(module, doc_file):
def _get_model_name(module):
""" Get the model name for the module defining it."""
splits = module.__name__.split("_")
module_name = module.__name__.split(".")[-1]
splits = module_name.split("_")
splits = splits[(2 if splits[1] in ["flax", "tf"] else 1) :]
return "_".join(splits)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment