Unverified commit 40b049c7, authored by Sylvain Gugger, committed by GitHub
Browse files

Check copies blackify (#10775)

* Apply black before checking copies

* Fix for class methods

* Deal with lonely brackets

* Remove debug and add forward changes

* Separate copies and fix test

* Add black as a test dependency
parent 39373919
...@@ -228,7 +228,7 @@ extras["speech"] = deps_list("soundfile", "torchaudio") ...@@ -228,7 +228,7 @@ extras["speech"] = deps_list("soundfile", "torchaudio")
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["testing"] = ( extras["testing"] = (
deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar") deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets", "pytest-sugar", "black")
+ extras["retrieval"] + extras["retrieval"]
+ extras["modelcreation"] + extras["modelcreation"]
) )
......
...@@ -671,7 +671,6 @@ class M2M100Encoder(M2M100PreTrainedModel): ...@@ -671,7 +671,6 @@ class M2M100Encoder(M2M100PreTrainedModel):
self.init_weights() self.init_weights()
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoder.forward with MBart->M2M100
def forward( def forward(
self, self,
input_ids=None, input_ids=None,
...@@ -830,7 +829,6 @@ class M2M100Decoder(M2M100PreTrainedModel): ...@@ -830,7 +829,6 @@ class M2M100Decoder(M2M100PreTrainedModel):
self.init_weights() self.init_weights()
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoder.forward with MBart->M2M100
def forward( def forward(
self, self,
input_ids=None, input_ids=None,
......
...@@ -1398,6 +1398,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel): ...@@ -1398,6 +1398,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
""", """,
MOBILEBERT_START_DOCSTRING, MOBILEBERT_START_DOCSTRING,
) )
# Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice with Bert->MobileBert all-casing
class MobileBertForMultipleChoice(MobileBertPreTrainedModel): class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
...@@ -1417,7 +1418,6 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel): ...@@ -1417,7 +1418,6 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
output_type=MultipleChoiceModelOutput, output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
# Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice.forward with Bert->MobileBert all-casing
def forward( def forward(
self, self,
input_ids=None, input_ids=None,
......
...@@ -737,8 +737,10 @@ class RobertaModel(RobertaPreTrainedModel): ...@@ -737,8 +737,10 @@ class RobertaModel(RobertaPreTrainedModel):
the model is configured as a decoder. the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: ``1`` for the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
tokens that are NOT MASKED, ``0`` for MASKED tokens.
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
...@@ -754,9 +756,10 @@ class RobertaModel(RobertaPreTrainedModel): ...@@ -754,9 +756,10 @@ class RobertaModel(RobertaPreTrainedModel):
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
if not self.config.is_decoder: if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False use_cache = False
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
......
...@@ -872,7 +872,6 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): ...@@ -872,7 +872,6 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
return combined_attention_mask return combined_attention_mask
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoder.forward with MBart->Speech2Text
def forward( def forward(
self, self,
input_ids=None, input_ids=None,
......
...@@ -19,6 +19,8 @@ import sys ...@@ -19,6 +19,8 @@ import sys
import tempfile import tempfile
import unittest import unittest
import black
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
sys.path.append(os.path.join(git_repo_path, "utils")) sys.path.append(os.path.join(git_repo_path, "utils"))
...@@ -66,6 +68,7 @@ class CopyCheckTester(unittest.TestCase): ...@@ -66,6 +68,7 @@ class CopyCheckTester(unittest.TestCase):
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
if overwrite_result is not None: if overwrite_result is not None:
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
code = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
fname = os.path.join(self.transformer_dir, "new_code.py") fname = os.path.join(self.transformer_dir, "new_code.py")
with open(fname, "w") as f: with open(fname, "w") as f:
f.write(code) f.write(code)
...@@ -103,7 +106,7 @@ class CopyCheckTester(unittest.TestCase): ...@@ -103,7 +106,7 @@ class CopyCheckTester(unittest.TestCase):
) )
# Copy consistency with a really long name # Copy consistency with a really long name
long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReasonIReallyDontUnderstand" long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
self.check_copy_consistency( self.check_copy_consistency(
f"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->{long_class_name}", f"# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->{long_class_name}",
f"{long_class_name}LMPredictionHead", f"{long_class_name}LMPredictionHead",
......
...@@ -17,7 +17,8 @@ import argparse ...@@ -17,7 +17,8 @@ import argparse
import glob import glob
import os import os
import re import re
import tempfile
import black
# All paths are set with the intent you should run this script from the root of the repo with the command # All paths are set with the intent you should run this script from the root of the repo with the command
...@@ -27,6 +28,10 @@ PATH_TO_DOCS = "docs/source" ...@@ -27,6 +28,10 @@ PATH_TO_DOCS = "docs/source"
REPO_PATH = "." REPO_PATH = "."
def _should_continue(line, indent):
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\):\s*$", line) is not None
def find_code_in_transformers(object_name): def find_code_in_transformers(object_name):
""" Find and return the code source code of `object_name`.""" """ Find and return the code source code of `object_name`."""
parts = object_name.split(".") parts = object_name.split(".")
...@@ -62,7 +67,7 @@ def find_code_in_transformers(object_name): ...@@ -62,7 +67,7 @@ def find_code_in_transformers(object_name):
# We found the beginning of the class / func, now let's find the end (when the indent diminishes). # We found the beginning of the class / func, now let's find the end (when the indent diminishes).
start_index = line_index start_index = line_index
while line_index < len(lines) and (lines[line_index].startswith(indent) or len(lines[line_index]) <= 1): while line_index < len(lines) and _should_continue(lines[line_index], indent):
line_index += 1 line_index += 1
# Clean up empty lines at the end (if any). # Clean up empty lines at the end (if any).
while len(lines[line_index - 1]) <= 1: while len(lines[line_index - 1]) <= 1:
...@@ -76,23 +81,6 @@ _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+) ...@@ -76,23 +81,6 @@ _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)
_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)") _re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
def blackify(code):
    """
    Applies the black part of our `make style` command to `code`.

    Black is invoked as a command-line tool, so the code is round-tripped through a
    temporary file (line length 119, target py35, matching `make style`).

    Args:
        code (`str`): The source code to reformat.

    Returns:
        `str`: The code as reformatted by black.
    """
    import subprocess

    # Black only accepts syntactically valid top-level code, so an indented snippet
    # (e.g. a method body) is wrapped in a dummy class before formatting.
    has_indent = code.startswith(" ")
    if has_indent:
        code = f"class Bla:\n{code}"
    with tempfile.TemporaryDirectory() as d:
        fname = os.path.join(d, "tmp.py")
        with open(fname, "w", encoding="utf-8", newline="\n") as f:
            f.write(code)
        # Use an argument list with no shell so the temporary path can never be
        # misinterpreted by shell parsing (the previous os.system f-string could be).
        # A missing black executable now fails loudly (FileNotFoundError) instead of
        # silently returning the unformatted code.
        subprocess.run(["black", "-q", "--line-length", "119", "--target-version", "py35", fname], check=False)
        with open(fname, "r", encoding="utf-8", newline="\n") as f:
            result = f.read()
    # Strip the dummy-class wrapper back off if one was added.
    return result[len("class Bla:\n") :] if has_indent else result
def get_indent(code): def get_indent(code):
lines = code.split("\n") lines = code.split("\n")
idx = 0 idx = 0
...@@ -100,7 +88,18 @@ def get_indent(code): ...@@ -100,7 +88,18 @@ def get_indent(code):
idx += 1 idx += 1
if idx < len(lines): if idx < len(lines):
return re.search(r"^(\s*)\S", lines[idx]).groups()[0] return re.search(r"^(\s*)\S", lines[idx]).groups()[0]
return 0 return ""
def blackify(code):
    """
    Applies the black part of our `make style` command to `code`.
    """
    # Black only formats valid top-level code, so indented snippets (e.g. a method
    # body) are wrapped in a dummy class before formatting and unwrapped after.
    needs_wrapping = len(get_indent(code)) > 0
    if needs_wrapping:
        code = "class Bla:\n" + code
    mode = black.FileMode([black.TargetVersion.PY35], line_length=119)
    formatted = black.format_str(code, mode=mode)
    if needs_wrapping:
        return formatted[len("class Bla:\n") :]
    return formatted
def is_copy_consistent(filename, overwrite=False): def is_copy_consistent(filename, overwrite=False):
...@@ -136,9 +135,7 @@ def is_copy_consistent(filename, overwrite=False): ...@@ -136,9 +135,7 @@ def is_copy_consistent(filename, overwrite=False):
if line_index >= len(lines): if line_index >= len(lines):
break break
line = lines[line_index] line = lines[line_index]
should_continue = (len(line) <= 1 or line.startswith(indent)) and re.search( should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
f"^{indent}# End copy", line
) is None
# Clean up empty lines at the end (if any). # Clean up empty lines at the end (if any).
while len(lines[line_index - 1]) <= 1: while len(lines[line_index - 1]) <= 1:
line_index -= 1 line_index -= 1
...@@ -159,6 +156,11 @@ def is_copy_consistent(filename, overwrite=False): ...@@ -159,6 +156,11 @@ def is_copy_consistent(filename, overwrite=False):
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code) theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code) theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
# Blackify after replacement. To be able to do that, we need the header (class or function definition)
# from the previous line
theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
theoretical_code = theoretical_code[len(lines[start_index - 1]) :]
# Test for a diff and act accordingly. # Test for a diff and act accordingly.
if observed_code != theoretical_code: if observed_code != theoretical_code:
diffs.append([object_name, start_index]) diffs.append([object_name, start_index])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment