Unverified Commit 32dbb2d9 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

make style (#11442)

parent 04ab2ca6
......@@ -33,7 +33,7 @@ def _should_continue(line, indent):
def find_code_in_transformers(object_name):
""" Find and return the code source code of `object_name`."""
"""Find and return the code source code of `object_name`."""
parts = object_name.split(".")
i = 0
......@@ -193,7 +193,7 @@ def check_copies(overwrite: bool = False):
def get_model_list():
""" Extracts the model list from the README. """
"""Extracts the model list from the README."""
# If the introduction or the conclusion of the list change, the prompts may need to be updated.
_start_prompt = "🤗 Transformers currently provides the following architectures"
_end_prompt = "1. Want to contribute a new model?"
......@@ -224,7 +224,7 @@ def get_model_list():
def split_long_line_with_indent(line, max_per_line, indent):
""" Split the `line` so that it doesn't go over `max_per_line` and adds `indent` to new lines. """
"""Split the `line` so that it doesn't go over `max_per_line` and adds `indent` to new lines."""
words = line.split(" ")
lines = []
current_line = words[0]
......@@ -239,7 +239,7 @@ def split_long_line_with_indent(line, max_per_line, indent):
def convert_to_rst(model_list, max_per_line=None):
""" Convert `model_list` to rst format. """
"""Convert `model_list` to rst format."""
# Convert **[description](link)** to `description <link>`__
def _rep_link(match):
title, link = match.groups()
......@@ -298,7 +298,7 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
def check_model_list_copy(overwrite=False, max_per_line=119):
""" Check the model lists in the README and index.rst are consistent and maybe `overwrite`. """
"""Check the model lists in the README and index.rst are consistent and maybe `overwrite`."""
rst_list, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
start_prompt=" This list is updated automatically from the README",
......
......@@ -65,7 +65,7 @@ def find_backend(line):
def read_init():
""" Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
"""Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
......@@ -101,7 +101,7 @@ def read_init():
def create_dummy_object(name, backend_name):
""" Create the code for the dummy object corresponding to `name`."""
"""Create the code for the dummy object corresponding to `name`."""
_pretrained = [
"Config" "ForCausalLM",
"ForConditionalGeneration",
......@@ -130,7 +130,7 @@ def create_dummy_object(name, backend_name):
def create_dummy_files():
""" Create the content of the dummy files. """
"""Create the content of the dummy files."""
backend_specific_objects = read_init()
# For special correspondence backend to module name as used in the function requires_modulename
dummy_files = {}
......@@ -146,7 +146,7 @@ def create_dummy_files():
def check_dummies(overwrite=False):
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
"""Check if the dummy files are up to date and maybe `overwrite` with the right content."""
dummy_files = create_dummy_files()
# For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
short_names = {"torch": "pt"}
......
......@@ -119,7 +119,7 @@ transformers = spec.loader.load_module()
# If some modeling modules should be ignored for all checks, they should be added in the nested list
# _ignore_modules of this function.
def get_model_modules():
""" Get the model modules inside the transformers library. """
"""Get the model modules inside the transformers library."""
_ignore_modules = [
"modeling_auto",
"modeling_encoder_decoder",
......@@ -151,7 +151,7 @@ def get_model_modules():
def get_models(module):
""" Get the objects in module that are models."""
"""Get the objects in module that are models."""
models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module):
......@@ -166,7 +166,7 @@ def get_models(module):
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
""" Get the model test files."""
"""Get the model test files."""
_ignore_files = [
"test_modeling_common",
"test_modeling_encoder_decoder",
......@@ -187,7 +187,7 @@ def get_model_test_files():
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
# for the all_model_classes variable.
def find_tested_models(test_file):
""" Parse the content of test_file to detect what's in all_model_classes"""
"""Parse the content of test_file to detect what's in all_model_classes"""
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
content = f.read()
......@@ -205,7 +205,7 @@ def find_tested_models(test_file):
def check_models_are_tested(module, test_file):
""" Check models defined in module are tested in test_file."""
"""Check models defined in module are tested in test_file."""
defined_models = get_models(module)
tested_models = find_tested_models(test_file)
if tested_models is None:
......@@ -229,7 +229,7 @@ def check_models_are_tested(module, test_file):
def check_all_models_are_tested():
""" Check all models are properly tested."""
"""Check all models are properly tested."""
modules = get_model_modules()
test_files = get_model_test_files()
failures = []
......@@ -245,7 +245,7 @@ def check_all_models_are_tested():
def get_all_auto_configured_models():
""" Return the list of all models in at least one auto class."""
"""Return the list of all models in at least one auto class."""
result = set() # To avoid duplicates we concatenate all model classes in a set.
for attr_name in dir(transformers.models.auto.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
......@@ -271,7 +271,7 @@ def ignore_unautoclassed(model_name):
def check_models_are_auto_configured(module, all_auto_models):
""" Check models defined in module are each in an auto class."""
"""Check models defined in module are each in an auto class."""
defined_models = get_models(module)
failures = []
for model_name, _ in defined_models:
......@@ -285,7 +285,7 @@ def check_models_are_auto_configured(module, all_auto_models):
def check_all_models_are_auto_configured():
""" Check all models are each in an auto class."""
"""Check all models are each in an auto class."""
modules = get_model_modules()
all_auto_models = get_all_auto_configured_models()
failures = []
......@@ -301,7 +301,7 @@ _re_decorator = re.compile(r"^\s*@(\S+)\s+$")
def check_decorator_order(filename):
""" Check that in the test file `filename` the slow decorator is always last."""
"""Check that in the test file `filename` the slow decorator is always last."""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
decorator_before = None
......@@ -319,7 +319,7 @@ def check_decorator_order(filename):
def check_all_decorator_order():
""" Check that in all test files, the slow decorator is always last."""
"""Check that in all test files, the slow decorator is always last."""
errors = []
for fname in os.listdir(PATH_TO_TESTS):
if fname.endswith(".py"):
......@@ -334,7 +334,7 @@ def check_all_decorator_order():
def find_all_documented_objects():
""" Parse the content of all doc files to detect which classes and functions it documents"""
"""Parse the content of all doc files to detect which classes and functions it documents"""
documented_obj = []
for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
......@@ -454,7 +454,7 @@ def ignore_undocumented(name):
def check_all_objects_are_documented():
""" Check all models are properly documented."""
"""Check all models are properly documented."""
documented_objs = find_all_documented_objects()
modules = transformers._modules
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
......@@ -467,7 +467,7 @@ def check_all_objects_are_documented():
def check_repo_quality():
""" Check all models are properly tested and documented."""
"""Check all models are properly tested and documented."""
print("Checking all models are properly tested.")
check_all_decorator_order()
check_all_models_are_tested()
......
......@@ -159,7 +159,7 @@ def get_model_table_from_auto_modules():
def check_model_table(overwrite=False):
""" Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`. """
"""Check the model table in the index.rst is consistent with the state of the lib and maybe `overwrite`."""
current_table, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(PATH_TO_DOCS, "index.rst"),
start_prompt=" This table is updated automatically from the auto module",
......
......@@ -431,7 +431,7 @@ def _add_new_lines_before_doc_special_words(text):
def style_rst_file(doc_file, max_len=119, check_only=False):
""" Style one rst file `doc_file` to `max_len`."""
"""Style one rst file `doc_file` to `max_len`."""
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
doc = f.read()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment