check_copies.py 13.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import glob
import os
import re
20
21

import black
22
23
24
25
26


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
TRANSFORMERS_PATH = "src/transformers"
PATH_TO_DOCS = "docs/source"
REPO_PATH = "."

# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to date with)
FULL_COPIES = {"examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py"}

33

34
35
36
37
def _should_continue(line, indent):
    """
    Tell whether `line` still belongs to a block whose body is indented by `indent`.

    A line continues the block when it carries the block's indentation, when it is blank (at most one character, i.e.
    just the newline), or when it only closes a multi-line signature (`):` alone on its line).
    """
    if line.startswith(indent) or len(line) <= 1:
        return True
    closes_signature = re.search(r"^\s*\):\s*$", line)
    return closes_signature is not None


38
def find_code_in_transformers(object_name):
    """
    Find and return the source code of `object_name`.

    `object_name` is a dotted path relative to the `transformers` package (e.g.
    `models.bert.modeling_bert.BertSelfOutput`): the leading parts name a module file under `TRANSFORMERS_PATH`, the
    remaining parts name the (possibly nested) class/function inside that module.

    Returns:
        `str`: The lines following the object's `class`/`def` line up to the end of its body, with trailing empty
        lines removed.

    Raises:
        ValueError: if no prefix of `object_name` matches a module file, or if the object is not found in the module.
    """
    parts = object_name.split(".")
    i = 0

    # First let's find the module where our object lives: extend the candidate path one part at a time until
    # `{TRANSFORMERS_PATH}/{module}.py` exists.
    module = parts[i]
    while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")):
        i += 1
        if i < len(parts):
            module = os.path.join(module, parts[i])
    if i >= len(parts):
        raise ValueError(
            f"`object_name` should begin with the name of a module of transformers but got {object_name}."
        )

    with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Now let's find the class / func in the code! Each nested name is looked for one indentation level deeper.
    indent = ""
    line_index = 0
    for name in parts[i + 1 :]:
        # The name is escaped so it is matched literally, not interpreted as a regex.
        while (
            line_index < len(lines)
            and re.search(fr"^{indent}(class|def)\s+{re.escape(name)}(\(|\:)", lines[line_index]) is None
        ):
            line_index += 1
        indent += "    "
        line_index += 1

    if line_index >= len(lines):
        raise ValueError(f" {object_name} does not match any function or class in {module}.")

    # We found the beginning of the class / func, now let's find the end (when the indent diminishes).
    start_index = line_index
    while line_index < len(lines) and _should_continue(lines[line_index], indent):
        line_index += 1
    # Clean up empty lines at the end (if any). `lines[start_index - 1]` is the non-empty definition line, so this stops.
    while len(lines[line_index - 1]) <= 1:
        line_index -= 1

    code_lines = lines[start_index:line_index]
    return "".join(code_lines)


# Matches a `# Copied from transformers.xxx.yyy [replacements]` comment. Groups: the comment's indentation, the dotted
# object path (relative to `transformers`) and the optional trailing text holding the replacement patterns.
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
# Matches one `old->new [option]` replacement. Groups: the old text, the new text and the (possibly empty) option.
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
85
86


87
88
89
90
91
92
93
def get_indent(code):
    """
    Return the leading whitespace of the first non-blank line of `code`, or `""` if there is no such line.

    Lines that are empty or contain only whitespace are skipped. A whitespace-only line has no non-space character
    for the regex below to anchor on, so it must not be selected.
    """
    for line in code.split("\n"):
        if line.strip():
            return re.search(r"^(\s*)\S", line).groups()[0]
    return ""


def blackify(code):
    """
    Format `code` with black, the same way the black part of our `make style` command does.

    Indented code (e.g. the body of a class or function) is not valid at the top level, so it is temporarily placed
    under a dummy `class Bla:` header that is cut off the formatted result.
    """
    header = "class Bla:\n"
    wrapped = get_indent(code) != ""
    if wrapped:
        code = header + code
    formatted = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
    if wrapped:
        return formatted[len(header) :]
    return formatted
106
107


108
109
110
111
112
113
def is_copy_consistent(filename, overwrite=False):
    """
    Check if the code commented as a copy in `filename` matches the original.

    Each `# Copied from transformers.xxx` comment marks the start of a copied block. That block is compared with the
    code of `xxx`, after applying the optional `Old->New` replacements (with an optional `all-casing` option) listed
    after the object name.

    Args:
        filename (`str`): The file to check.
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to replace inconsistent copies by the original code and rewrite `filename`.

    Returns:
        `List[List]`: One `[object_name, start_index]` pair per inconsistent copy, `start_index` being the 0-based
        index of the first line of the copied code.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    diffs = []
    line_index = 0
    # Not a for loop cause `lines` is going to change (if `overwrite=True`).
    while line_index < len(lines):
        search = _re_copy_warning.search(lines[line_index])
        if search is None:
            line_index += 1
            continue

        # There is some copied code here, let's retrieve the original.
        indent, object_name, replace_pattern = search.groups()
        theoretical_code = find_code_in_transformers(object_name)
        theoretical_indent = get_indent(theoretical_code)

        # When the comment has the same indentation as the original body, the copy starts on the next line; otherwise
        # the comment sits above the `class`/`def` line, which is skipped as well.
        start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
        indent = theoretical_indent
        line_index = start_index

        # Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
        should_continue = True
        while line_index < len(lines) and should_continue:
            line_index += 1
            if line_index >= len(lines):
                break
            line = lines[line_index]
            should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
        # Clean up empty lines at the end (if any), never walking back past the start of the copy.
        while line_index > start_index and len(lines[line_index - 1]) <= 1:
            line_index -= 1

        observed_code_lines = lines[start_index:line_index]
        observed_code = "".join(observed_code_lines)

        # Before comparing, use the `replace_pattern` on the original code.
        if len(replace_pattern) > 0:
            patterns = replace_pattern.replace("with", "").split(",")
            patterns = [_re_replace_pattern.search(p) for p in patterns]
            for pattern in patterns:
                if pattern is None:
                    continue
                obj1, obj2, option = pattern.groups()
                theoretical_code = re.sub(obj1, obj2, theoretical_code)
                if option.strip() == "all-casing":
                    theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
                    theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)

            # Blackify after replacement. To be able to do that, we need the header (class or function definition)
            # from the previous line
            theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
            theoretical_code = theoretical_code[len(lines[start_index - 1]) :]

        # Test for a diff and act accordingly.
        if observed_code != theoretical_code:
            diffs.append([object_name, start_index])
            if overwrite:
                lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
                line_index = start_index + 1

    if overwrite and len(diffs) > 0:
        # Warn the user a file has been modified.
        print(f"Detected changes, rewriting {filename}.")
        with open(filename, "w", encoding="utf-8", newline="\n") as f:
            f.writelines(lines)
    return diffs
180
181
182
183
184
185


def check_copies(overwrite: bool = False):
    """
    Check every `# Copied from` block in the Python files under `TRANSFORMERS_PATH`, then the model list copy.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`): Whether to fix inconsistent copies in place.

    Raises:
        Exception: if `overwrite` is `False` and at least one copy is inconsistent.
    """
    # `glob` returns files in arbitrary order; sort them so inconsistencies are reported deterministically.
    all_files = sorted(glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True))
    diffs = []
    for filename in all_files:
        new_diffs = is_copy_consistent(filename, overwrite)
        diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
    if not overwrite and len(diffs) > 0:
        diff = "\n".join(diffs)
        raise Exception(
            "Found the following copy inconsistencies:\n"
            + diff
            + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
        )
    check_model_list_copy(overwrite=overwrite)


198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
def check_full_copies(overwrite: bool = False):
    """
    Check that each file listed as a key of `FULL_COPIES` is an exact copy of the file given as its value.

    Files are opened without newline translation, as everywhere else in this script. Differing line endings therefore
    count as an inconsistency, and an overwrite writes the source content unchanged on every platform.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to replace the content of inconsistent copies by the content of their source.

    Raises:
        Exception: if `overwrite` is `False` and at least one copy differs from its source.
    """
    diffs = []
    for target, source in FULL_COPIES.items():
        with open(source, "r", encoding="utf-8", newline="\n") as f:
            source_code = f.read()
        with open(target, "r", encoding="utf-8", newline="\n") as f:
            target_code = f.read()
        if source_code != target_code:
            if overwrite:
                print(f"Replacing the content of {target} by the one of {source}.")
                with open(target, "w", encoding="utf-8", newline="\n") as f:
                    f.write(source_code)
            else:
                diffs.append(f"- {target}: copy does not match {source}.")

    if not overwrite and len(diffs) > 0:
        diff = "\n".join(diffs)
        raise Exception(
            "Found the following copy inconsistencies:\n"
            + diff
            + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
        )


222
def get_model_list():
    """
    Extract the model list from the README.

    Returns the lines between the line starting with the start prompt and the line starting with the end prompt.
    Each `1.` entry that is wrapped over several lines is joined back onto a single line, and empty lines are dropped.

    Raises:
        ValueError: if the start or end prompt cannot be found in the README.
    """
    # If the introduction or the conclusion of the list change, the prompts may need to be updated.
    _start_prompt = "🤗 Transformers currently provides the following architectures"
    _end_prompt = "1. Want to contribute a new model?"
    with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Find the start of the list.
    start_index = 0
    while start_index < len(lines) and not lines[start_index].startswith(_start_prompt):
        start_index += 1
    if start_index >= len(lines):
        raise ValueError(f"Could not find the start of the model list (`{_start_prompt}`) in the README.")
    start_index += 1

    result = []
    current_line = ""
    end_index = start_index
    while end_index < len(lines) and not lines[end_index].startswith(_end_prompt):
        if lines[end_index].startswith("1."):
            # A new entry begins: store the previous one (if any).
            if len(current_line) > 1:
                result.append(current_line)
            current_line = lines[end_index]
        elif len(lines[end_index]) > 1:
            # Continuation of the current entry: drop its trailing newline and append this line.
            current_line = f"{current_line[:-1]} {lines[end_index].lstrip()}"
        end_index += 1
    if end_index >= len(lines):
        raise ValueError(f"Could not find the end of the model list (`{_end_prompt}`) in the README.")
    if len(current_line) > 1:
        result.append(current_line)

    return "".join(result)


def split_long_line_with_indent(line, max_per_line, indent):
    """
    Split `line` on spaces so that it doesn't go over `max_per_line` characters, prefixing every new line with
    `indent` spaces.

    A single word longer than `max_per_line` is kept whole on its own line.

    Args:
        line (`str`): The line to split.
        max_per_line (`int`): The maximum length of each resulting line.
        indent (`int`): The number of spaces added at the start of every line after the first one.

    Returns:
        `str`: The resulting lines joined with newlines.
    """
    words = line.split(" ")
    lines = []
    current_line = words[0]
    for word in words[1:]:
        if len(f"{current_line} {word}") > max_per_line:
            lines.append(current_line)
            current_line = " " * indent + word
        else:
            current_line = f"{current_line} {word}"
    lines.append(current_line)
    return "\n".join(lines)


def convert_to_rst(model_list, max_per_line=None):
    """
    Convert `model_list` (the markdown model list of the README) to rst format.

    Bold links to released documentation pages become relative `:doc:` references, other links become rst
    hyperlinks, entries are renumbered sequentially and, if `max_per_line` is set, long lines are wrapped.

    Args:
        model_list (`str`): The markdown list, one entry per line.
        max_per_line (`int`, *optional*): If set, the maximum length of a line in the result.

    Returns:
        `str`: The list in rst format.
    """

    # Convert **[description](link)** to a relative :doc:`description <page>` for released documentation pages, and
    # to a hard `description <link>`__ otherwise.
    def _rep_link(match):
        title, link = match.groups()
        # Keep hard links for the models not released yet
        if "master" in link or not link.startswith("https://huggingface.co/transformers"):
            return f"`{title} <{link}>`__"
        # Convert links to relative links otherwise
        link = link[len("https://huggingface.co/transformers/") :]
        # Only strip the extension when it is actually there, so other links are not truncated.
        if link.endswith(".html"):
            link = link[: -len(".html")]
        return f":doc:`{title} <{link}>`"

    model_list = re.sub(r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\*", _rep_link, model_list)

    # Convert [description](link) to `description <link>`__
    model_list = re.sub(r"\[([^\]]*)\]\(([^\)]*)\)", r"`\1 <\2>`__", model_list)

    # Enumerate the lines properly
    lines = model_list.split("\n")
    result = []
    for i, line in enumerate(lines):
        line = re.sub(r"^\s*(\d+)\.", f"{i+1}.", line)
        # Split the lines that are too long
        if max_per_line is not None and len(line) > max_per_line:
            prompt = re.search(r"^(\s*\d+\.\s+)\S", line)
            indent = len(prompt.groups()[0]) if prompt is not None else 0
            line = split_long_line_with_indent(line, max_per_line, indent)

        result.append(line)
    return "\n".join(result)


def _find_text_in_file(filename, start_prompt, end_prompt):
    """
    Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
    lines.

    Returns:
        `Tuple[str, int, int, List[str]]`: The text found, the index of its first line, the index one past its last
        line, and all the lines of the file.

    Raises:
        ValueError: if no line starts with `start_prompt`, or no line after it starts with `end_prompt`.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    # Find the start prompt.
    start_index = 0
    while start_index < len(lines) and not lines[start_index].startswith(start_prompt):
        start_index += 1
    if start_index >= len(lines):
        raise ValueError(f"Could not find a line starting with `{start_prompt}` in {filename}.")
    start_index += 1

    end_index = start_index
    while end_index < len(lines) and not lines[end_index].startswith(end_prompt):
        end_index += 1
    if end_index >= len(lines):
        raise ValueError(f"Could not find a line starting with `{end_prompt}` after `{start_prompt}` in {filename}.")
    end_index -= 1

    # Trim empty lines (at most one character: the newline) at both ends of the text.
    while len(lines[start_index]) <= 1:
        start_index += 1
    while len(lines[end_index]) <= 1:
        end_index -= 1
    end_index += 1

    return "".join(lines[start_index:end_index]), start_index, end_index, lines

326

def check_model_list_copy(overwrite=False, max_per_line=119):
    """
    Check the model lists in the README and index.rst are consistent and maybe `overwrite`.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to rewrite the list in `index.rst` from the README when they differ.
        max_per_line (`int`, *optional*, defaults to 119): Maximum line length of the generated rst list.

    Raises:
        ValueError: if the lists differ and `overwrite` is `False`.
    """
    index_file = os.path.join(PATH_TO_DOCS, "index.rst")
    rst_list, start_index, end_index, lines = _find_text_in_file(
        filename=index_file,
        start_prompt="    This list is updated automatically from the README",
        end_prompt="Supported frameworks",
    )
    md_list = get_model_list()
    converted_list = convert_to_rst(md_list, max_per_line=max_per_line)

    if converted_list != rst_list:
        if overwrite:
            with open(index_file, "w", encoding="utf-8", newline="\n") as f:
                f.writelines(lines[:start_index] + [converted_list] + lines[end_index:])
        else:
            raise ValueError(
                "The model list in the README changed and the list in `index.rst` has not been updated. Run "
                "`make fix-copies` to fix this."
            )


348
349
350
351
352
353
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    # Partial copies (`# Copied from` blocks and the model list) first, then files that are full copies of others.
    check_copies(args.fix_and_overwrite)
    check_full_copies(args.fix_and_overwrite)