style_doc.py 12.6 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Style utils for the docstrings of Python files and for .mdx documentation files."""

import argparse
import os
import re
import warnings
Sylvain Gugger's avatar
Sylvain Gugger committed
21

Sylvain Gugger's avatar
Sylvain Gugger committed
22
23

# Regexes
Sylvain Gugger's avatar
Sylvain Gugger committed
24
# Re pattern that catches list introduction (with potential indent)
25
# Group 1 captures the whole list marker (`- `, `* ` or a number followed by `. `) together with any whitespace
# that precedes it, so its length is the column at which the text of the item starts.
_re_list = re.compile(r"^(\s*-\s+|\s*\*\s+|\s*\d+\.\s+)")
Sylvain Gugger's avatar
Sylvain Gugger committed
26
27
28
29
30
31
# Re pattern that catches code block introduction (with potential indent).
# Group 1 is the indent before the fence, group 2 is whatever follows the fence (usually the language name).
_re_code = re.compile(r"^(\s*)```(.*)$")
# Re pattern that catches rst args blocks of the form `Parameters:`.
# Raw strings are required here: `\s` is not a valid escape sequence in a regular string literal.
_re_args = re.compile(r"^\s*(Args?|Arguments?|Params?|Parameters?):\s*$")
# Re pattern that catches return blocks of the form `Return:`.
_re_returns = re.compile(r"^\s*Returns?:\s*$")
Sylvain Gugger's avatar
Sylvain Gugger committed
32
33
34
35
# Matches the special tag to ignore some paragraphs. The tag is `docstyle-ignore`, introduced either by `..` or
# by `#`, with optional whitespace in between.
_re_doc_ignore = re.compile(r"(\.\.|#)\s*docstyle-ignore")


Sylvain Gugger's avatar
Sylvain Gugger committed
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# The two prompt markers `>>>` and `...`.
# NOTE(review): not referenced in the part of the file visible here -- confirm it is still needed before removing.
DOCTEST_PROMPTS = [">>>", "..."]


def is_empty_line(line):
    """
    Tells whether a line is blank.

    Args:
        line (`str`): The line to inspect.

    Returns:
        `bool`: `True` if the line has no character at all or only whitespace characters, `False` otherwise.
    """
    if len(line) == 0:
        return True
    return line.isspace()


def find_indent(line):
    """
    Returns the number of whitespace characters that start a line.

    Args:
        line (`str`): The line to inspect.

    Returns:
        `int`: The length of the leading whitespace of `line` (the full length for a whitespace-only line).
    """
    # Raw string: `\s` and `\S` are regex classes, not string escape sequences.
    search = re.search(r"^(\s*)(?:\S|$)", line)
    if search is None:
        return 0
    return len(search.group(1))
Sylvain Gugger's avatar
Sylvain Gugger committed
51
52


Sylvain Gugger's avatar
Sylvain Gugger committed
53
def format_text(text, max_len, prefix="", min_indent=None):
    """
    Format a text in the biggest lines possible with the constraint of a maximum length and an indentation.

    Args:
        text (`str`): The text to format
        max_len (`int`): The maximum length per line to use
        prefix (`str`, *optional*, defaults to `""`): A prefix that will be added to the text.
            The prefix doesn't count toward the indent (like a - introducing a list).
        min_indent (`int`, *optional*): The minimum indent of the text.
            If not set, will default to the length of the `prefix`.

    Returns:
        `str`: The formatted text.
    """
    # Collapse every run of whitespace (newlines included) to a single space. The result is stripped so that
    # leading/trailing whitespace does not turn into empty words, which would leave a trailing space on the last
    # line or create a line made of the indent only.
    text = re.sub(r"\s+", " ", text).strip()
    if min_indent is not None and len(prefix) < min_indent:
        # Left-pad the prefix so that the text starts at least at column `min_indent`.
        prefix = " " * (min_indent - len(prefix)) + prefix

    # Continuation lines are aligned with the text that follows the prefix on the first line.
    indent = " " * len(prefix)
    new_lines = []
    words = text.split(" ")
    current_line = f"{prefix}{words[0]}"
    for word in words[1:]:
        try_line = f"{current_line} {word}"
        if len(try_line) > max_len:
            # The word does not fit: close the current line and start a new one. A single word longer than
            # `max_len` is never split.
            new_lines.append(current_line)
            current_line = f"{indent}{word}"
        else:
            current_line = try_line
    new_lines.append(current_line)
    return "\n".join(new_lines)


Sylvain Gugger's avatar
Sylvain Gugger committed
88
89
90
def split_line_on_first_colon(line):
    """
    Cuts a line in two around its first colon.

    Args:
        line (`str`): The line to split.

    Returns:
        `Tuple[str, str]`: The text before the first colon and the text after it (the colon itself is dropped).
        When the line has no colon, the second element is an empty string.
    """
    before, _, after = line.partition(":")
    return before, after
91
92


Sylvain Gugger's avatar
Sylvain Gugger committed
93
94
95
def _next_line_indent(lines, idx, default):
    """
    Returns the indent of the line that follows `lines[idx]`, or `default` when `lines[idx]` is the last line.
    """
    if idx + 1 < len(lines):
        return find_indent(lines[idx + 1])
    return default


def style_docstring(docstring, max_len):
    """
    Style a docstring by making sure there is no useless whitespace and the maximum horizontal space is used.

    Regular paragraphs and list items are re-wrapped with `format_text`. Fenced code blocks, empty lines, `Args:`
    headers and parameter introductions shorter than `max_len` are copied as they are.

    Args:
        docstring (`str`): The docstring to style.
        max_len (`int`): The maximum length of each line.

    Returns:
        `str`: The styled docstring
    """
    lines = docstring.split("\n")
    new_lines = []

    # Initialization
    current_paragraph = None
    current_indent = -1
    in_code = False
    param_indent = -1
    prefix = ""

    # Special case for docstrings that begin with continuation of Args with no Args block.
    first = 0
    while first < len(lines) and is_empty_line(lines[first]):
        first += 1
    # `first < len(lines)` guards empty or whitespace-only docstrings, and `_next_line_indent` guards a docstring
    # whose only non-empty line is the last one: both used to raise an IndexError.
    if (
        first < len(lines)
        and len(lines[first]) > 1
        and lines[first].rstrip().endswith(":")
        and _next_line_indent(lines, first, -1) > find_indent(lines[first])
    ):
        param_indent = find_indent(lines[first])

    for idx, line in enumerate(lines):
        # Doing all re searches once for the one we need to repeat.
        list_search = _re_list.search(line)
        code_search = _re_code.search(line)

        # Are we starting a new paragraph?
        # New indentation or new line:
        new_paragraph = find_indent(line) != current_indent or is_empty_line(line)
        # List item
        new_paragraph = new_paragraph or list_search is not None
        # Code block beginning
        new_paragraph = new_paragraph or code_search is not None

        # In this case, we treat the current paragraph
        if not in_code and new_paragraph and current_paragraph is not None and len(current_paragraph) > 0:
            paragraph = " ".join(current_paragraph)
            new_lines.append(format_text(paragraph, max_len, prefix=prefix, min_indent=current_indent))
            current_paragraph = None

        if code_search is not None:
            if not in_code:
                # Opening fence: the lines that follow are buffered in `current_paragraph` until the closing fence.
                current_paragraph = []
                current_indent = len(code_search.groups()[0])
                prefix = ""
                if current_indent < param_indent:
                    param_indent = -1
            else:
                # Closing fence: the buffered code is emitted untouched, whatever its language (formatting of
                # Python examples is disabled). `extend` is used rather than appending the joined code so that
                # an empty code block does not gain a blank line.
                current_indent = -1
                new_lines.extend(current_paragraph)
                current_paragraph = None
            new_lines.append(line)
            in_code = not in_code

        elif in_code:
            current_paragraph.append(line)
        elif is_empty_line(line):
            current_paragraph = None
            current_indent = -1
            prefix = ""
            new_lines.append(line)
        elif list_search is not None:
            # The marker (with its indent) becomes the prefix of the first line of the item.
            prefix = list_search.groups()[0]
            current_indent = len(prefix)
            current_paragraph = [line[current_indent:]]
        elif _re_args.search(line):
            new_lines.append(line)
            # Parameters are expected at the indent of the line right after the header. There is none to track
            # when the header is the last line.
            param_indent = _next_line_indent(lines, idx, -1)
        elif current_paragraph is None or find_indent(line) != current_indent:
            indent = find_indent(line)
            # Special behavior for parameters intros.
            if indent == param_indent:
                # Special rules for some docstring where the Returns blocks has the same indent as the parameters.
                if _re_returns.search(line) is not None:
                    param_indent = -1
                    new_lines.append(line)
                elif len(line) < max_len:
                    new_lines.append(line)
                else:
                    # The intro is too long: keep what precedes the first colon on its own line and start a
                    # paragraph with the description.
                    intro, description = split_line_on_first_colon(line)
                    new_lines.append(intro + ":")
                    if len(description) != 0:
                        # Reuse the indent of the existing continuation line if there is one, otherwise indent
                        # the description by 4 more spaces than the intro.
                        next_indent = _next_line_indent(lines, idx, indent)
                        current_indent = next_indent if next_indent > indent else indent + 4
                        current_paragraph = [description.strip()]
                        prefix = ""
            else:
                # Check if we have exited the parameter block
                if indent < param_indent:
                    param_indent = -1

                current_paragraph = [line.strip()]
                current_indent = find_indent(line)
                prefix = ""
        else:
            # Same indent as the paragraph in progress: this line continues it.
            current_paragraph.append(line.lstrip())

    if current_paragraph is not None and len(current_paragraph) > 0:
        paragraph = " ".join(current_paragraph)
        new_lines.append(format_text(paragraph, max_len, prefix=prefix, min_indent=current_indent))

    return "\n".join(new_lines)
Sylvain Gugger's avatar
Sylvain Gugger committed
214
215


Sylvain Gugger's avatar
Sylvain Gugger committed
216
217
218
def style_file_docstrings(code_file, max_len=119, check_only=False):
    """
    Style all docstrings in a given file.

    The content of the file is cut on every triple double-quote: the pieces at odd positions are treated as
    docstrings and styled with `style_docstring`, unless the piece of code that precedes them contains the
    `docstyle-ignore` tag.

    Args:
        code_file (`str` or `os.PathLike`): The file in which we want to style the docstring.
        max_len (`int`): The maximum number of characters per line.
        check_only (`bool`, *optional*, defaults to `False`):
            Whether to restyle file or just check if they should be restyled.

    Returns:
        `bool`: Whether or not the file was or should be restyled.
    """
    with open(code_file, "r", encoding="utf-8", newline="\n") as f:
        code = f.read()

    # The delimiter is spelled with escaped quotes so that this file never contains a triple double-quote outside
    # of its own docstrings.
    # fmt: off
    delimiter = '\"\"\"'
    # fmt: on
    chunks = code.split(delimiter)
    styled_chunks = []
    for position, chunk in enumerate(chunks):
        inside_docstring = position % 2 == 1
        if inside_docstring and _re_doc_ignore.search(chunks[position - 1]) is None:
            chunk = style_docstring(chunk, max_len=max_len)
        styled_chunks.append(chunk)
    clean_code = delimiter.join(styled_chunks)

    needs_restyle = clean_code != code
    if needs_restyle and not check_only:
        print(f"Overwriting content of {code_file}.")
        with open(code_file, "w", encoding="utf-8", newline="\n") as f:
            f.write(clean_code)

    return needs_restyle


Sylvain Gugger's avatar
Sylvain Gugger committed
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
def style_mdx_file(mdx_file, max_len=119, check_only=False):
    """
    Style a MDX file by formatting all Python code samples.

    The formatting of the code samples is currently disabled: every code block is written back as it was read, so
    this function reports no change.

    Args:
        mdx_file (`str` or `os.PathLike`): The file in which we want to style the examples.
        max_len (`int`): The maximum number of characters per line.
        check_only (`bool`, *optional*, defaults to `False`):
            Whether to restyle file or just check if they should be restyled.

    Returns:
        `bool`: Whether or not the file was or should be restyled.
    """
    with open(mdx_file, "r", encoding="utf-8", newline="\n") as f:
        content = f.read()

    lines = content.split("\n")
    current_code = []
    in_code = False
    new_lines = []
    for line in lines:
        # Search once per line and reuse the match.
        code_search = _re_code.search(line)
        if code_search is not None:
            in_code = not in_code
            if in_code:
                # Opening fence: start buffering the lines of the sample.
                current_code = []
            else:
                # Closing fence: emit the buffered sample untouched, whatever its language. `extend` is used
                # rather than appending the joined code so that an empty code block does not gain a blank line
                # (which made the file look like it needed restyling).
                new_lines.extend(current_code)

            new_lines.append(line)
        elif in_code:
            current_code.append(line)
        else:
            new_lines.append(line)

    clean_content = "\n".join(new_lines)
    diff = clean_content != content
    if not check_only and diff:
        print(f"Overwriting content of {mdx_file}.")
        with open(mdx_file, "w", encoding="utf-8", newline="\n") as f:
            f.write(clean_content)

    return diff


Sylvain Gugger's avatar
Sylvain Gugger committed
299
300
def style_doc_files(*files, max_len=119, check_only=False):
    """
    Applies doc styling or checks everything is correct in a list of files.

    Folders are explored recursively, keeping only their sub-folders, `.mdx` files and `.py` files. Any other file
    passed explicitly is skipped with a warning.

    Args:
        files (several `str` or `os.PathLike`): The files to treat.
        max_len (`int`): The maximum number of characters per line.
        check_only (`bool`, *optional*, defaults to `False`):
            Whether to restyle file or just check if they should be restyled.

    Returns:
        List[`str`]: The list of files changed or that should be restyled.
    """
    changed = []
    for file in files:
        # The extension checks below need a string, so path-like objects are converted first.
        file = os.fspath(file)
        # Treat folders
        if os.path.isdir(file):
            children = [os.path.join(file, name) for name in os.listdir(file)]
            # Keep the same file types as the ones handled below, so that mdx files inside folders are styled
            # and unsupported files inside folders are skipped silently.
            children = [c for c in children if os.path.isdir(c) or c.endswith((".mdx", ".py"))]
            changed += style_doc_files(*children, max_len=max_len, check_only=check_only)
        # Treat mdx
        elif file.endswith(".mdx"):
            if style_mdx_file(file, max_len=max_len, check_only=check_only):
                changed.append(file)
        # Treat python files
        elif file.endswith(".py"):
            try:
                if style_file_docstrings(file, max_len=max_len, check_only=check_only):
                    changed.append(file)
            except Exception:
                # Tell which file triggered the failure, then let the error propagate.
                print(f"There is a problem in {file}.")
                raise
        else:
            warnings.warn(f"Ignoring {file} because it's not a py or an mdx file or a folder.")
    return changed


def main(*files, max_len=119, check_only=False):
    """
    Styles the given files or folders and reports the outcome.

    Args:
        files (several `str` or `os.PathLike`): The files or folders to treat.
        max_len (`int`, *optional*, defaults to 119): The maximum number of characters per line.
        check_only (`bool`, *optional*, defaults to `False`):
            Whether to restyle file or just check if they should be restyled.

    Raises:
        `ValueError`: If `check_only=True` and at least one file should be restyled.
    """
    changed = style_doc_files(*files, max_len=max_len, check_only=check_only)
    if not changed:
        # Nothing to report.
        return
    if check_only:
        raise ValueError(f"{len(changed)} files should be restyled!")
    print(f"Cleaned {len(changed)} files!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("files", nargs="+", help="The file(s) or folder(s) to restyle.")
    # Same default as the functions of this module. Without it, omitting the flag passes `None` down to the
    # length comparisons, which fail with a TypeError.
    parser.add_argument("--max_len", type=int, default=119, help="The maximum length of lines.")
    parser.add_argument("--check_only", action="store_true", help="Whether to only check and not fix styling issues.")
    args = parser.parse_args()

    main(*args.files, max_len=args.max_len, check_only=args.check_only)