Unverified Commit b5e2b183, authored by Sylvain Gugger, committed by GitHub

Doc styler examples (#14953)

* Fix bad examples

* Add black formatting to style_doc

* Use first nonempty line

* Put it at the right place

* Don't add spaces to empty lines

* Better templates

* Deal with triple quotes in docstrings

* Result of style_doc

* Enable mdx treatment and fix code examples in MDXs

* Result of doc styler on doc source files

* Last fixes

* Break copy from
parent e13f72fb
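The heart of the change is running black over the code examples embedded in docstrings and MDX files: strip the doctest prompts, format the bare code with black, then put the prompts back. Below is a minimal, self-contained sketch of that idea. It is illustrative only, not the actual implementation (that is the `format_code_example` helper added to the styler, diffed further down); the helper name `blackify_doctest` is made up, and it assumes the example contains no output lines.

```python
import black


def blackify_doctest(example: str, line_length: int = 119) -> str:
    """Rough sketch: reformat a prompt-only doctest example with black."""
    # Drop the ">>> " / "... " prompts so black sees plain Python.
    bare = "\n".join(line[4:] for line in example.splitlines())
    formatted = black.format_str(
        bare, mode=black.FileMode([black.TargetVersion.PY37], line_length=line_length - 4)
    )
    # Put the prompts back: indented or closing lines get "... ", everything else ">>> ".
    restored = []
    for line in formatted.rstrip().splitlines():
        prefix = "... " if line.startswith(" ") or line in (")", "]", "}") else ">>> "
        restored.append(prefix + line)
    return "\n".join(restored)


print(blackify_doctest(">>> pipeline('fill-mask', model='bert-base-uncased')"))
# >>> pipeline("fill-mask", model="bert-base-uncased")
```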
@@ -433,6 +433,7 @@ class Adafactor(Optimizer):
```python
from transformers.optimization import Adafactor, AdafactorSchedule
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
lr_scheduler = AdafactorSchedule(optimizer)
trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
@@ -452,7 +453,7 @@ class Adafactor(Optimizer):
weight_decay=0.0,
relative_step=False,
scale_parameter=False,
-warmup_init=False
+warmup_init=False,
)
```"""
@@ -469,15 +469,15 @@ def pipeline(
>>> from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
>>> # Sentiment analysis pipeline
->>> pipeline('sentiment-analysis')
+>>> pipeline("sentiment-analysis")
>>> # Question answering pipeline, specifying the checkpoint identifier
->>> pipeline('question-answering', model='distilbert-base-cased-distilled-squad', tokenizer='bert-base-cased')
+>>> pipeline("question-answering", model="distilbert-base-cased-distilled-squad", tokenizer="bert-base-cased")
>>> # Named entity recognition pipeline, passing in a specific model and tokenizer
>>> model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
->>> pipeline('ner', model=model, tokenizer=tokenizer)
+>>> pipeline("ner", model=model, tokenizer=tokenizer)
```"""
if model_kwargs is None:
    model_kwargs = {}
@@ -272,6 +272,7 @@ class TableQuestionAnsweringPipeline(Pipeline):
```python
import pandas as pd
table = pd.DataFrame.from_dict(data)
```
@@ -709,6 +709,7 @@ class CaptureStd:
# to capture stderr only with auto-replay
import sys
with CaptureStderr() as cs:
    print("Warning: ", file=sys.stderr)
assert "Warning" in cs.err
@@ -826,7 +827,7 @@ class CaptureLogger:
>>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
>>> with CaptureLogger(logger) as cl:
...     logger.info(msg)
->>> assert cl.out, msg+"\n"
+>>> assert cl.out, msg + "\n"
```
"""
@@ -878,8 +879,8 @@ def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
Usage :
```python
-with ExtendSysPath('/path/to/dir'):
-    mymodule = importlib.import_module('mymodule')
+with ExtendSysPath("/path/to/dir"):
+    mymodule = importlib.import_module("mymodule")
```
"""
@@ -73,6 +73,7 @@ class Trie:
>>> trie.add("Hello 友達")
>>> trie.data
{"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
>>> trie.add("Hello")
>>> trie.data
{"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
@@ -100,6 +101,7 @@ class Trie:
>>> trie = Trie()
>>> trie.split("[CLS] This is a extra_id_100")
["[CLS] This is a extra_id_100"]
>>> trie.add("[CLS]")
>>> trie.add("extra_id_1")
>>> trie.add("extra_id_100")
@@ -393,11 +395,11 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
```python
# Let's see how to increase the vocabulary of Bert model and tokenizer
-tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-model = BertModel.from_pretrained('bert-base-uncased')
-num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
-print('We have added', num_added_toks, 'tokens')
+tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+model = BertModel.from_pretrained("bert-base-uncased")
+num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
+print("We have added", num_added_toks, "tokens")
# Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
```"""
@@ -862,17 +862,17 @@ class SpecialTokensMixin:
```python
# Let's see how to add a new classification token to GPT-2
-tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-model = GPT2Model.from_pretrained('gpt2')
-special_tokens_dict = {'cls_token': '<CLS>'}
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+model = GPT2Model.from_pretrained("gpt2")
+special_tokens_dict = {"cls_token": "<CLS>"}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
-print('We have added', num_added_toks, 'tokens')
+print("We have added", num_added_toks, "tokens")
# Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
-assert tokenizer.cls_token == '<CLS>'
+assert tokenizer.cls_token == "<CLS>"
```"""
if not special_tokens_dict:
    return 0
@@ -929,11 +929,11 @@ class SpecialTokensMixin:
```python
# Let's see how to increase the vocabulary of Bert model and tokenizer
-tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
-model = BertModel.from_pretrained('bert-base-uncased')
-num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
-print('We have added', num_added_toks, 'tokens')
+tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+model = BertModel.from_pretrained("bert-base-uncased")
+num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
+print("We have added", num_added_toks, "tokens")
# Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
```"""
@@ -1585,22 +1585,22 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
```python
# We can't instantiate directly the base class *PreTrainedTokenizerBase* so let's show our examples on a derived class: BertTokenizer
# Download vocabulary from huggingface.co and cache.
-tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Download vocabulary from huggingface.co (user-uploaded) and cache.
-tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-german-cased')
+tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
# If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
-tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')
+tokenizer = BertTokenizer.from_pretrained("./test/saved_model/")
# If the tokenizer uses a single vocabulary file, you can point directly to this file
-tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')
+tokenizer = BertTokenizer.from_pretrained("./test/saved_model/my_vocab.txt")
# You can link tokens to special vocabulary when instantiating
-tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')
+tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", unk_token="<unk>")
# You should be sure '<unk>' is in the vocabulary when doing that.
# Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead)
-assert tokenizer.unk_token == '<unk>'
+assert tokenizer.unk_token == "<unk>"
```"""
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
@@ -201,7 +201,6 @@ class TrainerCallback:
```python
class PrinterCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        _ = logs.pop("total_flos", None)
        if state.is_local_process_zero:
@@ -870,7 +870,7 @@ def log_metrics(self, split, metrics):
Now when this method is run, you will see a report that will include: :
-```python
+```
init_mem_cpu_alloc_delta = 1301MB
init_mem_cpu_peaked_delta = 154MB
init_mem_gpu_alloc_delta = 230MB
@@ -300,7 +300,7 @@ class TrainerMemoryTracker:
```python
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
self._memory_tracker.start()
-code ...
+# code ...
metrics = {"train_runtime": 10.5}
self._memory_tracker.stop_and_update_metrics(metrics)
```
@@ -526,6 +526,7 @@ def symbolic_trace(
```python
from transformers.utils.fx import symbolic_trace
traced_model = symbolic_trace(
    model,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
@@ -19,6 +19,16 @@ import os
import re
import warnings
+import black
+
+BLACK_AVOID_PATTERNS = {
+    "===PT-TF-SPLIT===": "### PT-TF-SPLIT",
+    "{processor_class}": "FakeProcessorClass",
+    "{model_class}": "FakeModelClass",
+    "{object_class}": "FakeObjectClass",
+}
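These placeholders exist because the doc templates contain tokens such as `===PT-TF-SPLIT===` or `{model_class}` that black either cannot parse or would mangle, so the styler swaps them for harmless stand-ins before formatting and swaps them back afterwards. A small illustration of the problem (assumed black behavior, not part of the commit):

```python
import black

snippet = "outputs = model(**inputs)\n===PT-TF-SPLIT===\noutputs = model(inputs)\n"
try:
    black.format_str(snippet, mode=black.FileMode())
except Exception as exc:
    # black rejects the raw template: "===PT-TF-SPLIT===" is not valid Python
    print("black fails on the raw template:", exc)

# Replacing the marker with a comment (as BLACK_AVOID_PATTERNS does) makes the snippet
# parseable; the reverse replacement restores the marker after formatting.
safe = snippet.replace("===PT-TF-SPLIT===", "### PT-TF-SPLIT")
print(black.format_str(safe, mode=black.FileMode()).replace("### PT-TF-SPLIT", "===PT-TF-SPLIT==="))
```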
# Regexes
# Re pattern that catches list introduction (with potential indent)
@@ -50,6 +60,136 @@ def find_indent(line):
    return len(search.groups()[0])
+def parse_code_example(code_lines):
+    """
+    Parses a code example.
+
+    Args:
+        code_lines (`List[str]`): The code lines to parse.
+        max_len (`int`): The maximum length per line.
+
+    Returns:
+        (List[`str`], List[`str`]): The list of code samples and the list of outputs.
+    """
+    has_doctest = code_lines[0][:3] in DOCTEST_PROMPTS
+    code_samples = []
+    outputs = []
+    in_code = True
+    current_bit = []
+    for line in code_lines:
+        if in_code and has_doctest and not is_empty_line(line) and line[:3] not in DOCTEST_PROMPTS:
+            code_sample = "\n".join(current_bit)
+            code_samples.append(code_sample.strip())
+            in_code = False
+            current_bit = []
+        elif not in_code and line[:3] in DOCTEST_PROMPTS:
+            output = "\n".join(current_bit)
+            outputs.append(output.strip())
+            in_code = True
+            current_bit = []
+        # Add the line without doctest prompt
+        if line[:3] in DOCTEST_PROMPTS:
+            line = line[4:]
+        current_bit.append(line)
+    # Add last sample
+    if in_code:
+        code_sample = "\n".join(current_bit)
+        code_samples.append(code_sample.strip())
+    else:
+        output = "\n".join(current_bit)
+        outputs.append(output.strip())
+    return code_samples, outputs
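To make the parsing above concrete, here is roughly what it yields on a small doctest block (an illustrative sketch; it assumes `DOCTEST_PROMPTS` covers the `>>>` and `...` prompts and is defined elsewhere in the script):

```python
lines = [
    '>>> trie = Trie()',
    '>>> trie.add("Hello")',
    '>>> trie.data',
    '{"H": {"e": {"l": {"l": {"o": {"": 1}}}}}}',
]
samples, outputs = parse_code_example(lines)
# samples -> ['trie = Trie()\ntrie.add("Hello")\ntrie.data']
# outputs -> ['{"H": {"e": {"l": {"l": {"o": {"": 1}}}}}}']
```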
+def format_code_example(code: str, max_len: int, in_docstring: bool = False):
+    """
+    Format a code example using black. Will take into account the doctest syntax as well as any initial indentation in
+    the code provided.
+
+    Args:
+        code (`str`): The code example to format.
+        max_len (`int`): The maximum length per line.
+        in_docstring (`bool`, *optional*, defaults to `False`): Whether or not the code example is inside a docstring.
+
+    Returns:
+        `str`: The formatted code.
+    """
+    code_lines = code.split("\n")
+    # Find initial indent
+    idx = 0
+    while idx < len(code_lines) and is_empty_line(code_lines[idx]):
+        idx += 1
+    if idx >= len(code_lines):
+        return "", ""
+    indent = find_indent(code_lines[idx])
+    # Remove the initial indent for now, we will add it back after styling.
+    # Note that l[indent:] works for empty lines
+    code_lines = [l[indent:] for l in code_lines[idx:]]
+    has_doctest = code_lines[0][:3] in DOCTEST_PROMPTS
+    code_samples, outputs = parse_code_example(code_lines)
+    # Let's blackify the code! We put everything in one big text to go faster.
+    delimiter = "\n\n### New code sample ###\n"
+    full_code = delimiter.join(code_samples)
+    line_length = max_len - indent
+    if has_doctest:
+        line_length -= 4
+    for k, v in BLACK_AVOID_PATTERNS.items():
+        full_code = full_code.replace(k, v)
+    try:
+        formatted_code = black.format_str(
+            full_code, mode=black.FileMode([black.TargetVersion.PY37], line_length=line_length)
+        )
+        error = ""
+    except Exception as e:
+        formatted_code = full_code
+        error = f"Code sample:\n{full_code}\n\nError message:\n{e}"
+    # Let's get back the formatted code samples
+    for k, v in BLACK_AVOID_PATTERNS.items():
+        formatted_code = formatted_code.replace(v, k)
+    # Triple quotes will mess docstrings.
+    if in_docstring:
+        formatted_code = formatted_code.replace('"""', "'''")
+    code_samples = formatted_code.split(delimiter)
+    # We can have one output less than code samples
+    if len(outputs) == len(code_samples) - 1:
+        outputs.append("")
+    formatted_lines = []
+    for code_sample, output in zip(code_samples, outputs):
+        # black may have added some new lines, we remove them
+        code_sample = code_sample.strip()
+        in_triple_quotes = False
+        for line in code_sample.strip().split("\n"):
+            if has_doctest and not is_empty_line(line):
+                prefix = "... " if line.startswith(" ") or line in [")", "]", "}"] or in_triple_quotes else ">>> "
+            else:
+                prefix = ""
+            indent_str = "" if is_empty_line(line) else (" " * indent)
+            formatted_lines.append(indent_str + prefix + line)
+            if '"""' in line:
+                in_triple_quotes = not in_triple_quotes
+        formatted_lines.extend([" " * indent + line for line in output.split("\n")])
+        if not output.endswith("===PT-TF-SPLIT==="):
+            formatted_lines.append("")
+    result = "\n".join(formatted_lines)
+    return result.rstrip(), error
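End to end, the formatter above is expected to behave roughly like this on a docstring example (a sketch only: the exact result depends on the installed black version, and the doctest output is made up purely to show how output lines are carried through):

```python
example = """
    >>> from transformers import pipeline
    >>> classifier = pipeline('sentiment-analysis')
    >>> classifier("I love this!")
    [{'label': 'POSITIVE', 'score': 0.99}]
"""
formatted, error = format_code_example(example, max_len=119, in_docstring=True)
print(formatted)  # prompts and the 4-space indent are kept; black re-quotes 'sentiment-analysis' to "sentiment-analysis"
print(error)      # empty string unless black could not parse one of the code samples
```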
def format_text(text, max_len, prefix="", min_indent=None):
    """
    Format a text in the biggest lines possible with the constraint of a maximum length and an indentation.
@@ -110,6 +250,7 @@ def style_docstring(docstring, max_len):
    in_code = False
    param_indent = -1
    prefix = ""
+    black_errors = []
    # Special case for docstrings that begin with continuation of Args with no Args block.
    idx = 0
@@ -153,8 +294,10 @@ def style_docstring(docstring, max_len):
            current_indent = -1
            code = "\n".join(current_paragraph)
            if current_code in ["py", "python"]:
-                new_lines.append(code)
-                # new_lines.append(format_code_example(code, max_len))
+                formatted_code, error = format_code_example(code, max_len, in_docstring=True)
+                new_lines.append(formatted_code)
+                if len(error) > 0:
+                    black_errors.append(error)
            else:
                new_lines.append(code)
            current_paragraph = None
@@ -210,7 +353,7 @@ def style_docstring(docstring, max_len):
        paragraph = " ".join(current_paragraph)
        new_lines.append(format_text(paragraph, max_len, prefix=prefix, min_indent=current_indent))
-    return "\n".join(new_lines)
+    return "\n".join(new_lines), "\n\n".join(black_errors)


def style_file_docstrings(code_file, max_len=119, check_only=False):
@@ -234,6 +377,8 @@ def style_file_docstrings(code_file, max_len=119, check_only=False):
        (s if i % 2 == 0 or _re_doc_ignore.search(splits[i - 1]) is not None else style_docstring(s, max_len=max_len))
        for i, s in enumerate(splits)
    ]
+    black_errors = "\n\n".join([s[1] for s in splits if isinstance(s, tuple) and len(s[1]) > 0])
+    splits = [s[0] if isinstance(s, tuple) else s for s in splits]
    clean_code = '\"\"\"'.join(splits)
    # fmt: on
@@ -243,7 +388,7 @@ def style_file_docstrings(code_file, max_len=119, check_only=False):
        with open(code_file, "w", encoding="utf-8", newline="\n") as f:
            f.write(clean_code)
-    return diff
+    return diff, black_errors


def style_mdx_file(mdx_file, max_len=119, check_only=False):
@@ -267,6 +412,8 @@ def style_mdx_file(mdx_file, max_len=119, check_only=False):
    current_language = ""
    in_code = False
    new_lines = []
+    black_errors = []
    for line in lines:
        if _re_code.search(line) is not None:
            in_code = not in_code
@@ -276,8 +423,9 @@ def style_mdx_file(mdx_file, max_len=119, check_only=False):
            else:
                code = "\n".join(current_code)
                if current_language in ["py", "python"]:
-                    pass
-                    # code = format_code_example(code, max_len)
+                    code, error = format_code_example(code, max_len)
+                    if len(error) > 0:
+                        black_errors.append(error)
                new_lines.append(code)
            new_lines.append(line)
@@ -293,7 +441,7 @@ def style_mdx_file(mdx_file, max_len=119, check_only=False):
        with open(mdx_file, "w", encoding="utf-8", newline="\n") as f:
            f.write(clean_content)
-    return diff
+    return diff, "\n\n".join(black_errors)


def style_doc_files(*files, max_len=119, check_only=False):
@@ -310,26 +458,49 @@ def style_doc_files(*files, max_len=119, check_only=False):
        List[`str`]: The list of files changed or that should be restyled.
    """
    changed = []
+    black_errors = []
    for file in files:
        # Treat folders
        if os.path.isdir(file):
            files = [os.path.join(file, f) for f in os.listdir(file)]
-            files = [f for f in files if os.path.isdir(f) or f.endswith(".rst") or f.endswith(".py")]
+            files = [f for f in files if os.path.isdir(f) or f.endswith(".mdx") or f.endswith(".py")]
            changed += style_doc_files(*files, max_len=max_len, check_only=check_only)
        # Treat mdx
        elif file.endswith(".mdx"):
-            if style_mdx_file(file, max_len=max_len, check_only=check_only):
+            try:
+                diff, black_error = style_mdx_file(file, max_len=max_len, check_only=check_only)
+                if diff:
                    changed.append(file)
+                if len(black_error) > 0:
+                    black_errors.append(
+                        f"There was a problem while formatting an example in {file} with black:\n{black_error}"
+                    )
+            except Exception:
+                print(f"There is a problem in {file}.")
+                raise
        # Treat python files
        elif file.endswith(".py"):
            try:
-                if style_file_docstrings(file, max_len=max_len, check_only=check_only):
+                diff, black_error = style_file_docstrings(file, max_len=max_len, check_only=check_only)
+                if diff:
                    changed.append(file)
+                if len(black_error) > 0:
+                    black_errors.append(
+                        f"There was a problem while formatting an example in {file} with black:\n{black_error}"
+                    )
            except Exception:
                print(f"There is a problem in {file}.")
                raise
        else:
            warnings.warn(f"Ignoring {file} because it's not a py or an mdx file or a folder.")
+    if len(black_errors) > 0:
+        black_message = "\n\n".join(black_errors)
+        raise ValueError(
+            "Some code examples can't be interpreted by black, which means they aren't regular python:\n\n"
+            + black_message
+            + "\n\nMake sure to fix the corresponding docstring or doc file, or remove the py/python after ``` if it "
+            + "was not supposed to be a Python code sample."
+        )
    return changed
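Putting it together, the styler can now be pointed at source and documentation trees and will raise a `ValueError` if any example cannot be parsed as Python. A hypothetical dry run (the paths are only illustrative):

```python
# check_only=True is intended to report files needing a restyle without rewriting them.
changed = style_doc_files("src/transformers", "docs/source", max_len=119, check_only=True)
if changed:
    print("Files that need restyling:", *changed, sep="\n  ")
```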