style_doc.py 19.5 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Style utils for the .rst and the docstrings."""

import argparse
import os
import re
import warnings
Sylvain Gugger's avatar
Sylvain Gugger committed
21

Sylvain Gugger's avatar
Sylvain Gugger committed
22
23
24
25
26
27
28
29
30
31
import black


# Placeholder -> stand-in text. `format_code_example` replaces each key by its value before handing the code to
# black, then replaces each value back by its key in the formatted result.
BLACK_AVOID_PATTERNS = {
    "===PT-TF-SPLIT===": "### PT-TF-SPLIT",
    "{processor_class}": "FakeProcessorClass",
    "{model_class}": "FakeModelClass",
    "{object_class}": "FakeObjectClass",
}

Sylvain Gugger's avatar
Sylvain Gugger committed
32
33

# Regexes
# All patterns are raw strings so that `\s`, `\d` and `\.` reach the regex engine untouched instead of being
# (invalid) string escape sequences.
# Re pattern that catches list introduction (with potential indent)
_re_list = re.compile(r"^(\s*-\s+|\s*\*\s+|\s*\d+\.\s+)")
# Re pattern that catches code block introduction (with potential indent)
_re_code = re.compile(r"^(\s*)```(.*)$")
# Re pattern that catches rst args blocks of the form `Parameters:`.
_re_args = re.compile(r"^\s*(Args?|Arguments?|Params?|Parameters?):\s*$")
# Re pattern that catches return blocks of the form `Return:`.
_re_returns = re.compile(r"^\s*Returns?:\s*$")
# Matches the special tag to ignore some paragraphs.
_re_doc_ignore = re.compile(r"(\.\.|#)\s*docstyle-ignore")
# Re pattern that matches <Tip>, </Tip> and <Tip warning={true}> blocks.
_re_tip = re.compile(r"^\s*</?Tip(>|\s+warning={true}>)\s*$")
Sylvain Gugger's avatar
Sylvain Gugger committed
46

Sylvain Gugger's avatar
Sylvain Gugger committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Three-character prefixes compared against `line[:3]` to recognize doctest prompt lines.
DOCTEST_PROMPTS = [">>>", "..."]


def is_empty_line(line):
    """
    Returns `True` when `line` is the empty string or contains nothing but whitespace.
    """
    # Stripping whitespace leaves an empty (falsy) string exactly in those two cases.
    return not line.strip()


def find_indent(line):
    """
    Returns the number of whitespace characters that start a line (its indent).

    Args:
        line (`str`): The line to inspect.

    Returns:
        `int`: The length of the leading whitespace. For a line made only of whitespace, this is the length of the
        whole line.
    """
    # Raw string: in a regular literal, `\s` and `\S` are invalid escape sequences.
    search = re.search(r"^(\s*)(?:\S|$)", line)
    if search is None:
        return 0
    return len(search.groups()[0])
Sylvain Gugger's avatar
Sylvain Gugger committed
62
63


Sylvain Gugger's avatar
Sylvain Gugger committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def parse_code_example(code_lines):
    """
    Splits the lines of a code example into code samples and their outputs.

    When the first line starts with a doctest prompt, prompt lines (and the empty lines between them) are grouped
    into code samples and the other lines into outputs. Otherwise, every line belongs to one single code sample.

    Args:
        code_lines (`List[str]`): The code lines to parse.

    Returns:
        (List[`str`], List[`str`]): The list of code samples and the list of outputs.
    """
    has_doctest = code_lines[0][:3] in DOCTEST_PROMPTS

    code_samples, outputs = [], []
    buffer = []
    reading_code = True

    for line in code_lines:
        prompted = line[:3] in DOCTEST_PROMPTS
        if reading_code:
            # A non-empty line without prompt closes the current code sample and opens an output
            if has_doctest and not prompted and not is_empty_line(line):
                code_samples.append("\n".join(buffer).strip())
                buffer = []
                reading_code = False
        elif prompted:
            # A prompt closes the current output and opens a new code sample
            outputs.append("\n".join(buffer).strip())
            buffer = []
            reading_code = True

        # Store the line without its doctest prompt (prompt + one separator character)
        buffer.append(line[4:] if prompted else line)

    # Whatever is left in the buffer is the last sample or the last output
    remainder = "\n".join(buffer).strip()
    if reading_code:
        code_samples.append(remainder)
    else:
        outputs.append(remainder)

    return code_samples, outputs


def format_code_example(code: str, max_len: int, in_docstring: bool = False):
    """
    Format a code example using black. Will take into account the doctest syntax as well as any initial indentation in
    the code provided.

    Args:
        code (`str`): The code example to format.
        max_len (`int`): The maximum length per line.
        in_docstring (`bool`, *optional*, defaults to `False`): Whether or not the code example is inside a docstring.

    Returns:
        `Tuple[str, str]`: The formatted code and an error message. The error message is empty unless black raised
        an exception, in which case the code is returned without black formatting.
    """
    code_lines = code.split("\n")

    # Find initial indent
    idx = 0
    while idx < len(code_lines) and is_empty_line(code_lines[idx]):
        idx += 1
    if idx >= len(code_lines):
        # Nothing but empty lines: nothing to format and no error.
        return "", ""
    indent = find_indent(code_lines[idx])

    # Remove the initial indent for now, we will add it back after styling.
    # Note that l[indent:] works for empty lines
    code_lines = [l[indent:] for l in code_lines[idx:]]
    has_doctest = code_lines[0][:3] in DOCTEST_PROMPTS

    code_samples, outputs = parse_code_example(code_lines)

    # Let's blackify the code! We put everything in one big text to go faster.
    delimiter = "\n\n### New code sample ###\n"
    full_code = delimiter.join(code_samples)
    line_length = max_len - indent
    if has_doctest:
        # Leave room for the 4-character doctest prompt added back below.
        line_length -= 4

    for k, v in BLACK_AVOID_PATTERNS.items():
        full_code = full_code.replace(k, v)
    try:
        formatted_code = black.format_str(
            full_code, mode=black.FileMode([black.TargetVersion.PY37], line_length=line_length)
        )
        error = ""
    except Exception as e:
        # Keep the code as is and report the problem to the caller instead of failing here.
        formatted_code = full_code
        error = f"Code sample:\n{full_code}\n\nError message:\n{e}"

    # Let's get back the formatted code samples
    for k, v in BLACK_AVOID_PATTERNS.items():
        formatted_code = formatted_code.replace(v, k)
    # Triple quotes will mess docstrings.
    if in_docstring:
        formatted_code = formatted_code.replace('"""', "'''")

    code_samples = formatted_code.split(delimiter)
    # We can have one output less than code samples
    if len(outputs) == len(code_samples) - 1:
        outputs.append("")

    formatted_lines = []
    for code_sample, output in zip(code_samples, outputs):
        # black may have added some new lines, we remove them
        code_sample = code_sample.strip()
        in_triple_quotes = False
        for line in code_sample.split("\n"):
            if has_doctest and not is_empty_line(line):
                prefix = "... " if line.startswith(" ") or line in [")", "]", "}"] or in_triple_quotes else ">>> "
            else:
                prefix = ""
            indent_str = "" if is_empty_line(line) else (" " * indent)
            formatted_lines.append(indent_str + prefix + line)

            # Only an odd number of delimiters on the line opens or closes a multi-line string: a string opened
            # and closed on the same line must leave the state untouched.
            if line.count('"""') % 2 == 1:
                in_triple_quotes = not in_triple_quotes

        formatted_lines.extend([" " * indent + line for line in output.split("\n")])
        if not output.endswith("===PT-TF-SPLIT==="):
            formatted_lines.append("")

    result = "\n".join(formatted_lines)
    return result.rstrip(), error


Sylvain Gugger's avatar
Sylvain Gugger committed
194
def format_text(text, max_len, prefix="", min_indent=None):
    """
    Format a text in the biggest lines possible with the constraint of a maximum length and an indentation.

    Args:
        text (`str`): The text to format
        max_len (`int`): The maximum length per line to use
        prefix (`str`, *optional*, defaults to `""`): A prefix that will be added to the text.
            The prefix doesn't count toward the indent (like a - introducing a list).
        min_indent (`int`, *optional*): The minimum indent of the text.
            If not set, will default to the length of the `prefix`.

    Returns:
        `str`: The formatted text.
    """
    # Any run of whitespace (new lines included) becomes a single space.
    text = re.sub(r"\s+", " ", text)
    # Left-pad the prefix with spaces so that it is at least `min_indent` characters long.
    if min_indent is not None and len(prefix) < min_indent:
        prefix = " " * (min_indent - len(prefix)) + prefix

    # Continuation lines are aligned with the text that follows the prefix on the first line.
    indent = " " * len(prefix)
    new_lines = []
    words = text.split(" ")
    current_line = f"{prefix}{words[0]}"
    for word in words[1:]:
        try_line = f"{current_line} {word}"
        if len(try_line) > max_len:
            # The word doesn't fit: close the current line and start a new one with that word.
            new_lines.append(current_line)
            current_line = f"{indent}{word}"
        else:
            current_line = try_line
    new_lines.append(current_line)
    return "\n".join(new_lines)


Sylvain Gugger's avatar
Sylvain Gugger committed
229
230
231
def split_line_on_first_colon(line):
    """
    Cuts `line` at its first colon and returns the text before and the text after it (the colon is dropped). The
    second element is empty when `line` contains no colon.
    """
    before, _, after = line.partition(":")
    return before, after
232
233


Sylvain Gugger's avatar
Sylvain Gugger committed
234
235
236
def style_docstring(docstring, max_len):
    """
    Style a docstring by making sure there is no useless whitespace and the maximum horizontal space is used.

    Args:
        docstring (`str`): The docstring to style.
        max_len (`int`): The maximum length of each line.

    Returns:
        `Tuple[str, str]`: The styled docstring and the error messages returned by `format_code_example` for the
        Python code blocks, joined by blank lines (an empty string when there was none).
    """
    lines = docstring.split("\n")
    new_lines = []

    # Initialization
    current_paragraph = None
    current_indent = -1
    in_code = False
    param_indent = -1
    prefix = ""
    black_errors = []

    def next_line_indent(position):
        # Indent of the line following `position`, or -1 when `position` is the last line. -1 is smaller than any
        # real indent, so the comparisons below behave as if there was no deeper-indented follow-up line.
        if position + 1 < len(lines):
            return find_indent(lines[position + 1])
        return -1

    # Special case for docstrings that begin with continuation of Args with no Args block.
    idx = 0
    while idx < len(lines) and is_empty_line(lines[idx]):
        idx += 1
    if (
        idx < len(lines)
        and len(lines[idx]) > 1
        and lines[idx].rstrip().endswith(":")
        and next_line_indent(idx) > find_indent(lines[idx])
    ):
        param_indent = find_indent(lines[idx])

    for idx, line in enumerate(lines):
        # Doing all re searches once for the one we need to repeat.
        list_search = _re_list.search(line)
        code_search = _re_code.search(line)
        tip_search = _re_tip.search(line)

        # Are we starting a new paragraph?
        # New indentation or new line:
        new_paragraph = find_indent(line) != current_indent or is_empty_line(line)
        # List item
        new_paragraph = new_paragraph or list_search is not None
        # Code block beginning
        new_paragraph = new_paragraph or code_search is not None
        # Beginning/end of tip
        new_paragraph = new_paragraph or tip_search is not None

        # In this case, we treat the current paragraph
        if not in_code and new_paragraph and current_paragraph is not None and len(current_paragraph) > 0:
            paragraph = " ".join(current_paragraph)
            new_lines.append(format_text(paragraph, max_len, prefix=prefix, min_indent=current_indent))
            current_paragraph = None

        if code_search is not None:
            if not in_code:
                # Opening fence: start collecting the lines of the code block.
                current_paragraph = []
                current_indent = len(code_search.groups()[0])
                current_code = code_search.groups()[1]
                prefix = ""
                if current_indent < param_indent:
                    param_indent = -1
            else:
                # Closing fence: emit the collected block, formatted when it is Python code.
                current_indent = -1
                code = "\n".join(current_paragraph)
                if current_code in ["py", "python"]:
                    formatted_code, error = format_code_example(code, max_len, in_docstring=True)
                    new_lines.append(formatted_code)
                    if len(error) > 0:
                        black_errors.append(error)
                else:
                    new_lines.append(code)
                current_paragraph = None
            new_lines.append(line)
            in_code = not in_code

        elif in_code:
            current_paragraph.append(line)
        elif is_empty_line(line):
            current_paragraph = None
            current_indent = -1
            prefix = ""
            new_lines.append(line)
        elif list_search is not None:
            prefix = list_search.groups()[0]
            current_indent = len(prefix)
            current_paragraph = [line[current_indent:]]
        elif _re_args.search(line):
            new_lines.append(line)
            param_indent = next_line_indent(idx)
        elif tip_search is not None:
            # Add a new line before if not present (nothing to separate from when the tag comes first)
            if len(new_lines) > 0 and not is_empty_line(new_lines[-1]):
                new_lines.append("")
            new_lines.append(line)
            # Add a new line after if not present
            if idx < len(lines) - 1 and not is_empty_line(lines[idx + 1]):
                new_lines.append("")
        elif current_paragraph is None or find_indent(line) != current_indent:
            indent = find_indent(line)
            # Special behavior for parameters intros.
            if indent == param_indent:
                # Special rules for some docstring where the Returns blocks has the same indent as the parameters.
                if _re_returns.search(line) is not None:
                    param_indent = -1
                    new_lines.append(line)
                elif len(line) < max_len or ":" not in line:
                    # Short enough, or no colon to split on: keep the line as is (splitting a line without colon
                    # would append a spurious ":" to it).
                    new_lines.append(line)
                else:
                    intro, description = split_line_on_first_colon(line)
                    new_lines.append(intro + ":")
                    if len(description) != 0:
                        following_indent = next_line_indent(idx)
                        if following_indent > indent:
                            current_indent = following_indent
                        else:
                            current_indent = indent + 4
                        current_paragraph = [description.strip()]
                        prefix = ""
            else:
                # Check if we have exited the parameter block
                if indent < param_indent:
                    param_indent = -1

                current_paragraph = [line.strip()]
                current_indent = find_indent(line)
                prefix = ""
        elif current_paragraph is not None:
            current_paragraph.append(line.lstrip())

    # Flush the paragraph still pending once all the lines have been read.
    if current_paragraph is not None and len(current_paragraph) > 0:
        paragraph = " ".join(current_paragraph)
        new_lines.append(format_text(paragraph, max_len, prefix=prefix, min_indent=current_indent))

    return "\n".join(new_lines), "\n\n".join(black_errors)
Sylvain Gugger's avatar
Sylvain Gugger committed
368
369


Sylvain Gugger's avatar
Sylvain Gugger committed
370
371
372
def style_file_docstrings(code_file, max_len=119, check_only=False):
    """
    Style all docstrings in a given file.

    Args:
        code_file (`str` or `os.PathLike`): The file in which we want to style the docstring.
        max_len (`int`): The maximum number of characters per line.
        check_only (`bool`, *optional*, defaults to `False`):
            Whether to restyle file or just check if they should be restyled.

    Returns:
        `Tuple[bool, str]`: Whether or not the file was or should be restyled, and the error messages returned by
        `style_docstring`, joined by blank lines (an empty string when there was none).
    """
    with open(code_file, "r", encoding="utf-8", newline="\n") as f:
        code = f.read()

    # fmt: off
    splits = code.split('\"\"\"')
    # fmt: on

    # Pieces at odd positions sit between two triple quotes: those are styled, unless the piece right before
    # them carries the docstyle-ignore tag.
    styled_splits = []
    errors = []
    for i, s in enumerate(splits):
        if i % 2 == 0 or _re_doc_ignore.search(splits[i - 1]) is not None:
            styled_splits.append(s)
            continue
        styled, error = style_docstring(s, max_len=max_len)
        styled_splits.append(styled)
        if len(error) > 0:
            errors.append(error)
    black_errors = "\n\n".join(errors)

    # fmt: off
    clean_code = '\"\"\"'.join(styled_splits)
    # fmt: on

    diff = clean_code != code
    if not check_only and diff:
        print(f"Overwriting content of {code_file}.")
        with open(code_file, "w", encoding="utf-8", newline="\n") as f:
            f.write(clean_code)

    return diff, black_errors
Sylvain Gugger's avatar
Sylvain Gugger committed
403
404


Sylvain Gugger's avatar
Sylvain Gugger committed
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
def style_mdx_file(mdx_file, max_len=119, check_only=False):
    """
    Style a MDX file by formatting all Python code samples.

    Args:
        mdx_file (`str` or `os.PathLike`): The file in which we want to style the examples.
        max_len (`int`): The maximum number of characters per line.
        check_only (`bool`, *optional*, defaults to `False`):
            Whether to restyle file or just check if they should be restyled.

    Returns:
        `Tuple[bool, str]`: Whether or not the file was or should be restyled, and the error messages returned by
        `format_code_example`, joined by blank lines (an empty string when there was none).

    Raises:
        `ValueError`: If a code block is opened without being closed.
    """
    with open(mdx_file, "r", encoding="utf-8", newline="\n") as f:
        content = f.read()

    lines = content.split("\n")
    current_code = []
    current_language = ""
    in_code = False
    new_lines = []
    black_errors = []

    for line in lines:
        code_search = _re_code.search(line)
        if code_search is not None:
            in_code = not in_code
            if in_code:
                # Opening fence: remember the language and start collecting the code lines.
                current_language = code_search.groups()[1]
                current_code = []
            else:
                # Closing fence: emit the collected code, formatted when it is Python.
                code = "\n".join(current_code)
                if current_language in ["py", "python"]:
                    code, error = format_code_example(code, max_len)
                    if len(error) > 0:
                        black_errors.append(error)
                new_lines.append(code)

            new_lines.append(line)
        elif in_code:
            current_code.append(line)
        else:
            new_lines.append(line)

    if in_code:
        raise ValueError(f"There was a problem when styling {mdx_file}. A code block is opened without being closed.")

    clean_content = "\n".join(new_lines)
    diff = clean_content != content
    if not check_only and diff:
        print(f"Overwriting content of {mdx_file}.")
        with open(mdx_file, "w", encoding="utf-8", newline="\n") as f:
            f.write(clean_content)

    return diff, "\n\n".join(black_errors)
Sylvain Gugger's avatar
Sylvain Gugger committed
459
460


Sylvain Gugger's avatar
Sylvain Gugger committed
461
462
def style_doc_files(*files, max_len=119, check_only=False):
    """
    Applies doc styling or checks everything is correct in a list of files.

    Args:
        files (several `str` or `os.PathLike`): The files to treat.
        max_len (`int`): The maximum number of characters per line.
        check_only (`bool`, *optional*, defaults to `False`):
            Whether to restyle file or just check if they should be restyled.

    Returns:
        List[`str`]: The list of files changed or that should be restyled.

    Raises:
        `ValueError`: If black could not format some of the Python code examples.
    """
    changed = []
    black_errors = []
    for file in files:
        # Treat folders
        if os.path.isdir(file):
            # Use a dedicated name so the argument being iterated over is not rebound.
            sub_files = [os.path.join(file, f) for f in os.listdir(file)]
            sub_files = [f for f in sub_files if os.path.isdir(f) or f.endswith(".mdx") or f.endswith(".py")]
            changed += style_doc_files(*sub_files, max_len=max_len, check_only=check_only)
        # Treat mdx and python files: same handling, only the styling function differs
        elif file.endswith(".mdx") or file.endswith(".py"):
            style_fn = style_mdx_file if file.endswith(".mdx") else style_file_docstrings
            try:
                diff, black_error = style_fn(file, max_len=max_len, check_only=check_only)
                if diff:
                    changed.append(file)
                if len(black_error) > 0:
                    black_errors.append(
                        f"There was a problem while formatting an example in {file} with black:\n{black_error}"
                    )
            except Exception:
                # Tell which file failed before letting the exception propagate.
                print(f"There is a problem in {file}.")
                raise
        else:
            warnings.warn(f"Ignoring {file} because it's not a py or an mdx file or a folder.")

    if len(black_errors) > 0:
        black_message = "\n\n".join(black_errors)
        raise ValueError(
            "Some code examples can't be interpreted by black, which means they aren't regular python:\n\n"
            + black_message
            + "\n\nMake sure to fix the corresponding docstring or doc file, or remove the py/python after ``` if it "
            + "was not supposed to be a Python code sample."
        )

    return changed


def main(*files, max_len=119, check_only=False):
    """
    Runs `style_doc_files` on `files` and reports how many files were (or need to be) restyled.

    Raises:
        `ValueError`: If `check_only` is set and at least one file should be restyled.
    """
    changed = style_doc_files(*files, max_len=max_len, check_only=check_only)
    if not changed:
        return
    if check_only:
        raise ValueError(f"{len(changed)} files should be restyled!")
    print(f"Cleaned {len(changed)} files!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("files", nargs="+", help="The file(s) or folder(s) to restyle.")
    # Default to 119 (same default as `main`): without it, omitting --max_len passes `None` down to the
    # formatting functions, which do arithmetic and comparisons on it.
    parser.add_argument("--max_len", type=int, default=119, help="The maximum length of lines.")
    parser.add_argument("--check_only", action="store_true", help="Whether to only check and not fix styling issues.")
    args = parser.parse_args()

    main(*args.files, max_len=args.max_len, check_only=args.check_only)