style_doc.py 20.2 KB
Newer Older
Sylvain Gugger's avatar
Sylvain Gugger committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Style utils for the .rst and the docstrings."""

import argparse
import os
import re
import warnings
Sylvain Gugger's avatar
Sylvain Gugger committed
21

Sylvain Gugger's avatar
Sylvain Gugger committed
22
23
24
25
26
27
28
29
30
31
import black


# Maps special markers found in code examples to placeholders. `format_code_example` replaces each key by its value
# before handing the code to black, and restores the key afterwards.
BLACK_AVOID_PATTERNS = {
    "===PT-TF-SPLIT===": "### PT-TF-SPLIT",
    "{processor_class}": "FakeProcessorClass",
    "{model_class}": "FakeModelClass",
    "{object_class}": "FakeObjectClass",
}

Sylvain Gugger's avatar
Sylvain Gugger committed
32
33

# Regexes
Sylvain Gugger's avatar
Sylvain Gugger committed
34
# Re pattern that catches list introduction (with potential indent)
35
# The single group captures the whole list marker (`- `, `* ` or `1. `) together with the whitespace around it.
_re_list = re.compile(r"^(\s*-\s+|\s*\*\s+|\s*\d+\.\s+)")
Sylvain Gugger's avatar
Sylvain Gugger committed
36
37
38
39
40
41
# Re pattern that catches code block introduction (with potential indent). Group 1 is the indent, group 2 is whatever
# follows the backticks (the language, if any).
_re_code = re.compile(r"^(\s*)```(.*)$")
# Re pattern that catches rst args blocks of the form `Parameters:`.
_re_args = re.compile(r"^\s*(Args?|Arguments?|Params?|Parameters?):\s*$")
# Re pattern that catches return blocks of the form `Return:`.
_re_returns = re.compile(r"^\s*Returns?:\s*$")
Sylvain Gugger's avatar
Sylvain Gugger committed
42
43
# Matches the special tag to ignore some paragraphs: `docstyle-ignore` preceded by `..` or `#`. When the text right
# before a docstring contains this tag, `style_docstrings_in_code` leaves that docstring untouched.
_re_doc_ignore = re.compile(r"(\.\.|#)\s*docstyle-ignore")
Sylvain Gugger's avatar
Sylvain Gugger committed
44
45
# Re pattern that matches <Tip>, </Tip> and <Tip warning={true}> blocks (alone on their line, with potential indent).
_re_tip = re.compile(r"^\s*</?Tip(>|\s+warning={true}>)\s*$")
Sylvain Gugger's avatar
Sylvain Gugger committed
46

Sylvain Gugger's avatar
Sylvain Gugger committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Doctest prompts. A line whose first three characters are one of these is treated as code (not output) in code
# examples.
DOCTEST_PROMPTS = [">>>", "..."]


def is_empty_line(line):
    """
    Tells whether a line is empty or only made of whitespace.
    """
    return not line.strip()


def find_indent(line):
    """
    Returns the number of whitespace characters that start a line indent.

    Args:
        line (`str`): The line to inspect.

    Returns:
        `int`: The length of the leading whitespace (the whole length for a whitespace-only line).
    """
    # Raw string: `\s` and `\S` are not valid escapes in a regular string literal.
    search = re.search(r"^(\s*)(?:\S|$)", line)
    if search is None:
        return 0
    return len(search.group(1))
Sylvain Gugger's avatar
Sylvain Gugger committed
62
63


Sylvain Gugger's avatar
Sylvain Gugger committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def parse_code_example(code_lines):
    """
    Parses a code example, separating the code samples from the outputs that follow them.

    When the first line starts with a doctest prompt, every non-empty line without a prompt is treated as output.
    Otherwise everything is treated as a single code sample.

    Args:
        code_lines (`List[str]`): The code lines to parse.

    Returns:
        (List[`str`], List[`str`]): The list of code samples and the list of outputs.
    """
    # An empty list has no first line to inspect; treat it as a single empty, prompt-less sample.
    has_doctest = len(code_lines) > 0 and code_lines[0][:3] in DOCTEST_PROMPTS

    code_samples = []
    outputs = []
    in_code = True
    current_bit = []

    for line in code_lines:
        if in_code and has_doctest and not is_empty_line(line) and line[:3] not in DOCTEST_PROMPTS:
            # First output line after some code: close the current code sample.
            code_sample = "\n".join(current_bit)
            code_samples.append(code_sample.strip())
            in_code = False
            current_bit = []
        elif not in_code and line[:3] in DOCTEST_PROMPTS:
            # First code line after some output: close the current output.
            output = "\n".join(current_bit)
            outputs.append(output.strip())
            in_code = True
            current_bit = []

        # Add the line without doctest prompt (the prompt and the space that follows it)
        if line[:3] in DOCTEST_PROMPTS:
            line = line[4:]
        current_bit.append(line)

    # Add last sample
    if in_code:
        code_sample = "\n".join(current_bit)
        code_samples.append(code_sample.strip())
    else:
        output = "\n".join(current_bit)
        outputs.append(output.strip())

    return code_samples, outputs


def format_code_example(code: str, max_len: int, in_docstring: bool = False):
    """
    Format a code example using black. Will take into account the doctest syntax as well as any initial indentation in
    the code provided.

    Args:
        code (`str`): The code example to format.
        max_len (`int`): The maximum length per line.
        in_docstring (`bool`, *optional*, defaults to `False`): Whether or not the code example is inside a docstring.

    Returns:
        `Tuple[str, str]`: The formatted code and the error reported by black (empty string when black succeeded).
        Both are empty strings when `code` only contains empty lines.
    """
    code_lines = code.split("\n")

    # Find initial indent
    idx = 0
    while idx < len(code_lines) and is_empty_line(code_lines[idx]):
        idx += 1
    if idx >= len(code_lines):
        return "", ""
    indent = find_indent(code_lines[idx])

    # Remove the initial indent for now, we will add it back after styling.
    # Note that line[indent:] works for empty lines
    code_lines = [line[indent:] for line in code_lines[idx:]]
    has_doctest = code_lines[0][:3] in DOCTEST_PROMPTS

    code_samples, outputs = parse_code_example(code_lines)

    # Let's blackify the code! We put everything in one big text to go faster.
    delimiter = "\n\n### New code sample ###\n"
    full_code = delimiter.join(code_samples)
    line_length = max_len - indent
    if has_doctest:
        # Leave room for the doctest prompt that is added back below.
        line_length -= 4

    for k, v in BLACK_AVOID_PATTERNS.items():
        full_code = full_code.replace(k, v)
    try:
        mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=line_length)
        formatted_code = black.format_str(full_code, mode=mode)
        error = ""
    except Exception as e:
        # Keep the code as is and report the problem to the caller instead of failing here.
        formatted_code = full_code
        error = f"Code sample:\n{full_code}\n\nError message:\n{e}"

    # Let's get back the formatted code samples
    for k, v in BLACK_AVOID_PATTERNS.items():
        formatted_code = formatted_code.replace(v, k)
    # Triple quotes will mess docstrings.
    if in_docstring:
        formatted_code = formatted_code.replace('"""', "'''")

    code_samples = formatted_code.split(delimiter)
    # We can have one output less than code samples
    if len(outputs) == len(code_samples) - 1:
        outputs.append("")

    formatted_lines = []
    for code_sample, output in zip(code_samples, outputs):
        # black may have added some new lines, we remove them
        code_sample = code_sample.strip()
        in_triple_quotes = False
        in_decorator = False
        for line in code_sample.split("\n"):
            if has_doctest and not is_empty_line(line):
                prefix = (
                    "... "
                    if line.startswith(" ") or line in [")", "]", "}"] or in_triple_quotes or in_decorator
                    else ">>> "
                )
            else:
                prefix = ""
            indent_str = "" if is_empty_line(line) else (" " * indent)
            formatted_lines.append(indent_str + prefix + line)

            # Only an odd number of triple-quote tokens opens or closes a multi-line string: a string opened and
            # closed on the same line must not flip the state. Both quote styles are counted since double quotes
            # have been replaced by single ones when `in_docstring=True`.
            num_triple_quotes = line.count('"""') + line.count("'''")
            if num_triple_quotes % 2 == 1:
                in_triple_quotes = not in_triple_quotes
            if line.startswith(" "):
                in_decorator = False
            if line.startswith("@"):
                in_decorator = True

        formatted_lines.extend([" " * indent + line for line in output.split("\n")])
        if not output.endswith("===PT-TF-SPLIT==="):
            formatted_lines.append("")

    result = "\n".join(formatted_lines)
    return result.rstrip(), error


Sylvain Gugger's avatar
Sylvain Gugger committed
202
def format_text(text, max_len, prefix="", min_indent=None):
    """
    Format a text in the biggest lines possible with the constraint of a maximum length and an indentation.

    Args:
        text (`str`): The text to format
        max_len (`int`): The maximum length per line to use
        prefix (`str`, *optional*, defaults to `""`): A prefix that will be added to the text.
            The prefix doesn't count toward the indent (like a - introducing a list).
        min_indent (`int`, *optional*): The minimum indent of the text.
            If not set, will default to the length of the `prefix`.

    Returns:
        `str`: The formatted text.
    """
    # Collapse every run of whitespace and drop it at both ends: a leading or trailing space would otherwise become
    # an empty word below and leave a stray space in the result.
    text = re.sub(r"\s+", " ", text).strip()
    if min_indent is not None:
        if len(prefix) < min_indent:
            prefix = " " * (min_indent - len(prefix)) + prefix

    # Continuation lines are aligned with the text that follows the prefix.
    indent = " " * len(prefix)
    new_lines = []
    words = text.split(" ")
    current_line = f"{prefix}{words[0]}"
    for word in words[1:]:
        try_line = f"{current_line} {word}"
        if len(try_line) > max_len:
            new_lines.append(current_line)
            current_line = f"{indent}{word}"
        else:
            current_line = try_line
    new_lines.append(current_line)
    return "\n".join(new_lines)


Sylvain Gugger's avatar
Sylvain Gugger committed
237
238
239
def split_line_on_first_colon(line):
    """
    Splits a line in two around its first colon.

    Returns:
        `Tuple[str, str]`: What comes before the first colon and what comes after it (an empty string when the line
        has no colon).
    """
    before, _, after = line.partition(":")
    return before, after
240
241


Sylvain Gugger's avatar
Sylvain Gugger committed
242
243
244
def style_docstring(docstring, max_len):
    """
    Style a docstring by making sure there is no useless whitespace and the maximum horizontal space is used.

    Args:
        docstring (`str`): The docstring to style.
        max_len (`int`): The maximum length of each line.

    Returns:
        `Tuple[str, str]`: The styled docstring and the errors reported by black on its code examples (joined by
        blank lines, empty string if there are none).
    """
    lines = docstring.split("\n")
    num_lines = len(lines)
    new_lines = []

    # Initialization
    current_paragraph = None
    current_indent = -1
    in_code = False
    param_indent = -1
    prefix = ""
    black_errors = []

    # Special case for docstrings that begin with continuation of Args with no Args block.
    idx = 0
    while idx < num_lines and is_empty_line(lines[idx]):
        idx += 1
    # The first test makes sure both `lines[idx]` and `lines[idx + 1]` exist (the docstring may be empty or stop right
    # after its first non-empty line).
    if (
        idx + 1 < num_lines
        and len(lines[idx]) > 1
        and lines[idx].rstrip().endswith(":")
        and find_indent(lines[idx + 1]) > find_indent(lines[idx])
    ):
        param_indent = find_indent(lines[idx])

    for idx, line in enumerate(lines):
        has_next_line = idx + 1 < num_lines
        # Doing all re searches once for the one we need to repeat.
        list_search = _re_list.search(line)
        code_search = _re_code.search(line)

        # Are we starting a new paragraph?
        # New indentation or new line:
        new_paragraph = find_indent(line) != current_indent or is_empty_line(line)
        # List item
        new_paragraph = new_paragraph or list_search is not None
        # Code block beginning
        new_paragraph = new_paragraph or code_search is not None
        # Beginning/end of tip
        new_paragraph = new_paragraph or _re_tip.search(line) is not None

        # In this case, we treat the current paragraph
        if not in_code and new_paragraph and current_paragraph is not None and len(current_paragraph) > 0:
            paragraph = " ".join(current_paragraph)
            new_lines.append(format_text(paragraph, max_len, prefix=prefix, min_indent=current_indent))
            current_paragraph = None

        if code_search is not None:
            if not in_code:
                # Opening fence: start accumulating the lines of the code block.
                current_paragraph = []
                current_indent = len(code_search.groups()[0])
                current_code = code_search.groups()[1]
                prefix = ""
                if current_indent < param_indent:
                    param_indent = -1
            else:
                # Closing fence: emit the accumulated code, formatted if it is Python.
                current_indent = -1
                code = "\n".join(current_paragraph)
                if current_code in ["py", "python"]:
                    formatted_code, error = format_code_example(code, max_len, in_docstring=True)
                    new_lines.append(formatted_code)
                    if len(error) > 0:
                        black_errors.append(error)
                else:
                    new_lines.append(code)
                current_paragraph = None
            new_lines.append(line)
            in_code = not in_code

        elif in_code:
            current_paragraph.append(line)
        elif is_empty_line(line):
            current_paragraph = None
            current_indent = -1
            prefix = ""
            new_lines.append(line)
        elif list_search is not None:
            prefix = list_search.groups()[0]
            current_indent = len(prefix)
            current_paragraph = [line[current_indent:]]
        elif _re_args.search(line):
            new_lines.append(line)
            # The parameters are indented like the line following the header (if there is one).
            if has_next_line:
                param_indent = find_indent(lines[idx + 1])
        elif _re_tip.search(line):
            # Add a new line before if not present
            if len(new_lines) > 0 and not is_empty_line(new_lines[-1]):
                new_lines.append("")
            new_lines.append(line)
            # Add a new line after if not present
            if idx < num_lines - 1 and not is_empty_line(lines[idx + 1]):
                new_lines.append("")
        elif current_paragraph is None or find_indent(line) != current_indent:
            indent = find_indent(line)
            # Special behavior for parameters intros.
            if indent == param_indent:
                # Special rules for some docstring where the Returns blocks has the same indent as the parameters.
                if _re_returns.search(line) is not None:
                    param_indent = -1
                    new_lines.append(line)
                elif len(line) < max_len:
                    new_lines.append(line)
                else:
                    # The intro is too long: keep what comes before the first colon on its line and reflow the
                    # description below it.
                    intro, description = split_line_on_first_colon(line)
                    new_lines.append(intro + ":")
                    if len(description) != 0:
                        if has_next_line and find_indent(lines[idx + 1]) > indent:
                            current_indent = find_indent(lines[idx + 1])
                        else:
                            current_indent = indent + 4
                        current_paragraph = [description.strip()]
                        prefix = ""
            else:
                # Check if we have exited the parameter block
                if indent < param_indent:
                    param_indent = -1

                current_paragraph = [line.strip()]
                current_indent = find_indent(line)
                prefix = ""
        elif current_paragraph is not None:
            current_paragraph.append(line.lstrip())

    # Treat the paragraph that was still open at the end of the docstring.
    if current_paragraph is not None and len(current_paragraph) > 0:
        paragraph = " ".join(current_paragraph)
        new_lines.append(format_text(paragraph, max_len, prefix=prefix, min_indent=current_indent))

    return "\n".join(new_lines), "\n\n".join(black_errors)
Sylvain Gugger's avatar
Sylvain Gugger committed
376
377


378
def style_docstrings_in_code(code, max_len=119):
    """
    Style all docstrings in some code.

    Args:
        code (`str`): The code in which we want to style the docstrings.
        max_len (`int`): The maximum number of characters per line.

    Returns:
        `Tuple[str, str]`: A tuple with the clean code and the black errors (if any)
    """
    # The triple quotes are written with escapes so that this file does not contain the token outside of its own
    # docstrings.
    # fmt: off
    splits = code.split('\"\"\"')
    # fmt: on

    styled_splits = []
    errors = []
    for i, s in enumerate(splits):
        # Even chunks are regular code. Odd chunks are docstrings, left alone when the code right before them has
        # the docstyle-ignore tag.
        if i % 2 == 0 or _re_doc_ignore.search(splits[i - 1]) is not None:
            styled_splits.append(s)
            continue
        styled, error = style_docstring(s, max_len=max_len)
        styled_splits.append(styled)
        if len(error) > 0:
            errors.append(error)

    black_errors = "\n\n".join(errors)
    # fmt: off
    clean_code = '\"\"\"'.join(styled_splits)
    # fmt: on

    return clean_code, black_errors


def style_file_docstrings(code_file, max_len=119, check_only=False):
    """
    Style all docstrings in a given file.

    Args:
        code_file (`str` or `os.PathLike`): The file in which we want to style the docstring.
        max_len (`int`): The maximum number of characters per line.
        check_only (`bool`, *optional*, defaults to `False`):
            Whether to restyle file or just check if they should be restyled.

    Returns:
        `Tuple[bool, str]`: Whether or not the file was or should be restyled, and the black errors (if any).
    """
    with open(code_file, "r", encoding="utf-8", newline="\n") as f:
        code = f.read()

    clean_code, black_errors = style_docstrings_in_code(code, max_len=max_len)

    diff = clean_code != code
    # Only touch the file when something changed and we are not in check mode.
    if not check_only and diff:
        print(f"Overwriting content of {code_file}.")
        with open(code_file, "w", encoding="utf-8", newline="\n") as f:
            f.write(clean_code)

    return diff, black_errors
Sylvain Gugger's avatar
Sylvain Gugger committed
428
429


Sylvain Gugger's avatar
Sylvain Gugger committed
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
def style_mdx_file(mdx_file, max_len=119, check_only=False):
    """
    Style a MDX file by formatting all Python code samples.

    Args:
        mdx_file (`str` or `os.PathLike`): The file in which we want to style the examples.
        max_len (`int`): The maximum number of characters per line.
        check_only (`bool`, *optional*, defaults to `False`):
            Whether to restyle file or just check if they should be restyled.

    Returns:
        `Tuple[bool, str]`: Whether or not the file was or should be restyled, and the black errors (if any).

    Raises:
        `ValueError`: If a code block is opened without being closed.
    """
    with open(mdx_file, "r", encoding="utf-8", newline="\n") as f:
        content = f.read()

    lines = content.split("\n")
    current_code = []
    current_language = ""
    in_code = False
    new_lines = []
    black_errors = []

    for line in lines:
        code_search = _re_code.search(line)
        if code_search is not None:
            # Every fence line toggles between inside and outside of a code block.
            in_code = not in_code
            if in_code:
                current_language = code_search.groups()[1]
                current_code = []
            else:
                code = "\n".join(current_code)
                if current_language in ["py", "python"]:
                    code, error = format_code_example(code, max_len)
                    if len(error) > 0:
                        black_errors.append(error)
                new_lines.append(code)

            new_lines.append(line)
        elif in_code:
            current_code.append(line)
        else:
            new_lines.append(line)

    if in_code:
        raise ValueError(f"There was a problem when styling {mdx_file}. A code block is opened without being closed.")

    clean_content = "\n".join(new_lines)
    diff = clean_content != content
    if not check_only and diff:
        print(f"Overwriting content of {mdx_file}.")
        with open(mdx_file, "w", encoding="utf-8", newline="\n") as f:
            f.write(clean_content)

    return diff, "\n\n".join(black_errors)
Sylvain Gugger's avatar
Sylvain Gugger committed
484
485


Sylvain Gugger's avatar
Sylvain Gugger committed
486
487
def style_doc_files(*files, max_len=119, check_only=False):
    """
    Applies doc styling or checks everything is correct in a list of files.

    Args:
        files (several `str` or `os.PathLike`): The files to treat.
        max_len (`int`): The maximum number of characters per line.
        check_only (`bool`, *optional*, defaults to `False`):
            Whether to restyle file or just check if they should be restyled.

    Returns:
        List[`str`]: The list of files changed or that should be restyled.

    Raises:
        `ValueError`: If black could not format some of the code examples.
    """
    changed = []
    black_errors = []
    for file in files:
        # Accept `os.PathLike` objects as documented: the extension tests below need a plain path string.
        file = os.fspath(file)
        # Treat folders
        if os.path.isdir(file):
            children = [os.path.join(file, f) for f in os.listdir(file)]
            children = [f for f in children if os.path.isdir(f) or f.endswith(".mdx") or f.endswith(".py")]
            changed += style_doc_files(*children, max_len=max_len, check_only=check_only)
        # Treat mdx
        elif file.endswith(".mdx"):
            try:
                diff, black_error = style_mdx_file(file, max_len=max_len, check_only=check_only)
                if diff:
                    changed.append(file)
                if len(black_error) > 0:
                    black_errors.append(
                        f"There was a problem while formatting an example in {file} with black:\n{black_error}"
                    )
            except Exception:
                # Tell which file is at fault before letting the error propagate.
                print(f"There is a problem in {file}.")
                raise
        # Treat python files
        elif file.endswith(".py"):
            try:
                diff, black_error = style_file_docstrings(file, max_len=max_len, check_only=check_only)
                if diff:
                    changed.append(file)
                if len(black_error) > 0:
                    black_errors.append(
                        f"There was a problem while formatting an example in {file} with black:\n{black_error}"
                    )
            except Exception:
                # Tell which file is at fault before letting the error propagate.
                print(f"There is a problem in {file}.")
                raise
        else:
            warnings.warn(f"Ignoring {file} because it's not a py or an mdx file or a folder.")

    if len(black_errors) > 0:
        black_message = "\n\n".join(black_errors)
        raise ValueError(
            "Some code examples can't be interpreted by black, which means they aren't regular python:\n\n"
            + black_message
            + "\n\nMake sure to fix the corresponding docstring or doc file, or remove the py/python after ``` if it "
            + "was not supposed to be a Python code sample."
        )

    return changed


def main(*files, max_len=119, check_only=False):
    """
    Styles (or checks, when `check_only=True`) the given files and reports how many of them are concerned.

    Raises:
        `ValueError`: In check mode, if at least one file should be restyled.
    """
    num_changed = len(style_doc_files(*files, max_len=max_len, check_only=check_only))
    if num_changed == 0:
        return
    if check_only:
        raise ValueError(f"{num_changed} files should be restyled!")
    print(f"Cleaned {num_changed} files!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("files", nargs="+", help="The file(s) or folder(s) to restyle.")
    # Same default as `main`: without it, omitting the flag passes `None` and the line length arithmetic fails.
    parser.add_argument("--max_len", type=int, default=119, help="The maximum length of lines.")
    parser.add_argument("--check_only", action="store_true", help="Whether to only check and not fix styling issues.")
    args = parser.parse_args()

    main(*args.files, max_len=args.max_len, check_only=args.check_only)