# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import glob
import os
import re
20
import subprocess
21
22
23
24


# All paths are relative to the repository root, so run this script from the root of the repo with the command
# python utils/check_copies.py
# Root of the diffusers package: copied-code sources are resolved and scanned below this directory.
DIFFUSERS_PATH = "src/diffusers"
# Repository root. NOTE(review): not referenced by this script itself — confirm nothing imports it before removing.
REPO_PATH = "."

28

29
30
31
32
def _should_continue(line, indent):
    return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None


def find_code_in_diffusers(object_name):
    """
    Find and return the source code of `object_name`.

    Args:
        object_name (`str`):
            Dotted path of the object relative to the `diffusers` package. The shortest prefix of it that names a
            `.py` file under `DIFFUSERS_PATH` is taken as the module; the remaining parts name the (possibly nested)
            class / function inside that module.

    Returns:
        `str`: The lines following the first line of the object's `class`/`def` statement, up to where the
        indentation diminishes, with trailing empty lines removed.

    Raises:
        `ValueError`: If no prefix of `object_name` is a module file under `DIFFUSERS_PATH`, or if the object cannot
        be found in that module.
    """
    parts = object_name.split(".")
    i = 0

    # First let's find the module where our object lives: extend the path one part at a time until it names an
    # existing `.py` file.
    module = parts[i]
    while i < len(parts) and not os.path.isfile(os.path.join(DIFFUSERS_PATH, f"{module}.py")):
        i += 1
        if i < len(parts):
            module = os.path.join(module, parts[i])
    if i >= len(parts):
        raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.")

    with open(
        os.path.join(DIFFUSERS_PATH, f"{module}.py"),
        "r",
        encoding="utf-8",
        newline="\n",
    ) as f:
        lines = f.readlines()

    # Now let's find the class / func in the code! Each nested name is searched one indentation level (4 spaces)
    # deeper than the previous one, starting right after the line where the previous name was found.
    indent = ""
    line_index = 0
    for name in parts[i + 1 :]:
        while (
            line_index < len(lines)
            and re.search(rf"^{indent}(class|def)\s+{re.escape(name)}(\(|\:)", lines[line_index]) is None
        ):
            line_index += 1
        indent += "    "
        line_index += 1

    if line_index >= len(lines):
        raise ValueError(f" {object_name} does not match any function or class in {module}.")

    # We found the beginning of the class / func, now let's find the end (when the indent diminishes).
    start_index = line_index
    while line_index < len(lines) and _should_continue(lines[line_index], indent):
        line_index += 1
    # Clean up empty lines at the end (if any). Never step back past the start of the body, otherwise negative
    # indices would wrap around to the end of the file.
    while line_index > start_index and len(lines[line_index - 1]) <= 1:
        line_index -= 1

    code_lines = lines[start_index:line_index]
    return "".join(code_lines)


Patrick von Platen's avatar
Patrick von Platen committed
81
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)")
82
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
83
_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
84
85
86
87
88
89
90
91
92
93
94
95


def get_indent(code):
    """
    Return the leading whitespace of the first line of `code` that contains a non-whitespace character.

    Lines that are empty or whitespace-only are skipped (previously a whitespace-only line made the regex match
    fail and raised `AttributeError`). Returns `""` if no such line exists.
    """
    for line in code.split("\n"):
        match = re.search(r"^(\s*)\S", line)
        if match is not None:
            return match.group(1)
    return ""


def run_ruff(code):
    """
    Format `code` with `ruff format`, passing it on stdin and returning ruff's stdout.

    The configuration is read from `pyproject.toml` relative to the current working directory, so this is meant to
    be run from the repository root.

    Raises:
        `RuntimeError`: If ruff exits with a non-zero status (e.g. `code` cannot be parsed). Without this check the
        empty stdout of a failed run would be returned as if it were the formatted code.
    """
    command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
    stdout, stderr = process.communicate(input=code.encode())
    if process.returncode != 0:
        raise RuntimeError(
            f"`{' '.join(command)}` failed with exit code {process.returncode}:\n{stderr.decode(errors='replace')}"
        )
    return stdout.decode()


def stylify(code: str) -> str:
    """
    Applies the ruff part of our `make style` command to some code. This formats the code using `ruff format`.
    As `ruff` does not provide a python api this cannot be done on the fly.

    Indented code (e.g. a method body) is not valid at module level, so it is temporarily prefixed with a dummy
    `class Bla:` header before formatting and that many characters are cut from the result (this assumes ruff
    leaves the dummy header line unchanged).

    Args:
        code (`str`): The code to format.

    Returns:
        `str`: The formatted code.
    """
    dummy_header = "class Bla:\n"
    has_indent = len(get_indent(code)) > 0
    if has_indent:
        code = f"{dummy_header}{code}"
    formatted_code = run_ruff(code)
    return formatted_code[len(dummy_header) :] if has_indent else formatted_code


def is_copy_consistent(filename, overwrite=False):
    """
    Check if the code commented as a copy in `filename` matches the original.
    Return the differences or overwrites the content depending on `overwrite`.

    A copy starts with a `# Copied from diffusers.<module>.<object>[ with Old->New[ option], ...]` comment. The code
    that follows (until the indentation diminishes or a `# End copy` comment) is compared with the source of
    `<object>` after applying the optional replace patterns.

    Args:
        filename (`str`): The file to check.
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to replace inconsistent copies with the (pattern-replaced) original code and rewrite the file.

    Returns:
        `list`: One `[object_name, start_index]` entry per inconsistent copy, where `start_index` is the 0-based
        index of the first compared line in the list of lines being processed.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    diffs = []
    line_index = 0
    # Not a for loop cause `lines` is going to change (if `overwrite=True`).
    while line_index < len(lines):
        search = _re_copy_warning.search(lines[line_index])
        if search is None:
            line_index += 1
            continue

        # There is some copied code here, let's retrieve the original.
        indent, object_name, replace_pattern = search.groups()
        theoretical_code = find_code_in_diffusers(object_name)
        theoretical_indent = get_indent(theoretical_code)

        # When the comment's indent differs from the original body's indent, the line right after the comment is
        # the copy's own `class`/`def` header, which is not compared: skip it.
        start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
        indent = theoretical_indent
        line_index = start_index

        # Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
        # The line at `start_index` is always included; scanning starts at the next one.
        should_continue = True
        while line_index < len(lines) and should_continue:
            line_index += 1
            if line_index >= len(lines):
                break
            line = lines[line_index]
            should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
        # Clean up empty lines at the end (if any).
        while len(lines[line_index - 1]) <= 1:
            line_index -= 1

        observed_code_lines = lines[start_index:line_index]
        observed_code = "".join(observed_code_lines)

        # Remove any nested `Copied from` comments to avoid circular copies
        theoretical_code = [line for line in theoretical_code.split("\n") if _re_copy_warning.search(line) is None]
        theoretical_code = "\n".join(theoretical_code)

        # Before comparing, use the `replace_pattern` on the original code.
        if len(replace_pattern) > 0:
            # Drop the `with` keyword only where it is a whole word, so identifiers that merely contain "with"
            # (e.g. `foo_with->bar` or `withFoo->Bar`) are not mangled.
            patterns = re.sub(r"\bwith\b", "", replace_pattern).split(",")
            patterns = [_re_replace_pattern.search(p) for p in patterns]
            for pattern in patterns:
                if pattern is None:
                    continue
                obj1, obj2, option = pattern.groups()
                # NOTE: `obj1` is used as a regular expression (e.g. `.` matches any character).
                theoretical_code = re.sub(obj1, obj2, theoretical_code)
                if option.strip() == "all-casing":
                    theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
                    theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)

            # stylify after replacement. To be able to do that, we need the header (class or function definition)
            # from the previous line
            theoretical_code = stylify(lines[start_index - 1] + theoretical_code)
            theoretical_code = theoretical_code[len(lines[start_index - 1]) :]

        # Test for a diff and act accordingly.
        if observed_code != theoretical_code:
            diffs.append([object_name, start_index])
            if overwrite:
                # The replacement is inserted as a single (multi-line) element; scanning resumes right after it.
                lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
                line_index = start_index + 1

    if overwrite and len(diffs) > 0:
        # Warn the user a file has been modified.
        print(f"Detected changes, rewriting {filename}.")
        with open(filename, "w", encoding="utf-8", newline="\n") as f:
            f.writelines(lines)
    return diffs


def check_copies(overwrite: bool = False):
    """
    Check every `.py` file under `DIFFUSERS_PATH` for `# Copied from` code that no longer matches its source.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to fix the inconsistencies in place instead of reporting them.

    Raises:
        `Exception`: If `overwrite=False` and at least one inconsistent copy was found; the message lists them.
    """
    # Sort so files are processed and reported in a deterministic order; `glob` order depends on the file system.
    all_files = sorted(glob.glob(os.path.join(DIFFUSERS_PATH, "**/*.py"), recursive=True))
    diffs = []
    for filename in all_files:
        new_diffs = is_copy_consistent(filename, overwrite)
        diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
    if not overwrite and len(diffs) > 0:
        diff = "\n".join(diffs)
        raise Exception(
            "Found the following copy inconsistencies:\n"
            + diff
            + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check that code marked with `# Copied from diffusers...` comments matches its original."
    )
    parser.add_argument(
        "--fix_and_overwrite",
        action="store_true",
        help="Whether to fix inconsistencies.",
    )
    args = parser.parse_args()

    check_copies(args.fix_and_overwrite)