Unverified Commit e4b94d8e authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Copy code from Bert to Roberta and add safeguard script (#7219)



* Copy code from Bert to Roberta and add safeguard script

* Fix docstring

* Comment code

* Formatting

* Update src/transformers/modeling_roberta.py
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>

* Add test and fix bugs

* Fix style and make new command
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
parent 656c27c3
...@@ -247,6 +247,7 @@ jobs: ...@@ -247,6 +247,7 @@ jobs:
- run: black --check --line-length 119 --target-version py35 examples templates tests src utils - run: black --check --line-length 119 --target-version py35 examples templates tests src utils
- run: isort --check-only examples templates tests src utils - run: isort --check-only examples templates tests src utils
- run: flake8 examples templates tests src utils - run: flake8 examples templates tests src utils
- run: python utils/check_copies.py
- run: python utils/check_repo.py - run: python utils/check_repo.py
check_repository_consistency: check_repository_consistency:
working_directory: ~/transformers working_directory: ~/transformers
......
...@@ -6,6 +6,7 @@ quality: ...@@ -6,6 +6,7 @@ quality:
black --check --line-length 119 --target-version py35 examples templates tests src utils black --check --line-length 119 --target-version py35 examples templates tests src utils
isort --check-only examples templates tests src utils isort --check-only examples templates tests src utils
flake8 examples templates tests src utils flake8 examples templates tests src utils
python utils/check_copies.py
python utils/check_repo.py python utils/check_repo.py
# Format source code automatically # Format source code automatically
...@@ -14,6 +15,11 @@ style: ...@@ -14,6 +15,11 @@ style:
black --line-length 119 --target-version py35 examples templates tests src utils black --line-length 119 --target-version py35 examples templates tests src utils
isort examples templates tests src utils isort examples templates tests src utils
# Make marked copies of code snippets conform to the original
fix-copies:
python utils/check_copies.py --fix_and_overwrite
# Run tests for the library # Run tests for the library
test: test:
......
This diff is collapsed.
import os
import re
import shutil
import sys
import tempfile
import unittest
# Directory two levels above this file (the parent of the folder containing the tests).
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
# `check_copies` is a standalone script living in `utils/`, so that folder is added to `sys.path`
# to make it importable as a module.
sys.path.append(os.path.join(git_repo_path, "utils"))
import check_copies  # noqa: E402
# This is the reference code that will be used in the tests: it is compared verbatim against the text
# returned by `check_copies.find_code_in_transformers("modeling_bert.BertLMPredictionHead")`.
# If BertLMPredictionHead is changed in modeling_bert.py, this code needs to be manually updated.
# NOTE(review): the comparison is an exact string match, so indentation and blank lines inside this literal
# must match modeling_bert.py exactly. The whitespace shown here looks collapsed; verify it against the
# real class body.
REFERENCE_CODE = """ def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
"""
class CopyCheckTester(unittest.TestCase):
    """Tests for the helpers in `utils/check_copies.py`.

    Each test runs against a temporary directory that contains a copy of `modeling_bert.py`, and
    `check_copies.TRANSFORMERS_PATH` is pointed at that directory for the duration of the test.
    """

    def setUp(self):
        """Create the temporary source directory and redirect `check_copies` to it."""
        self.transformer_dir = tempfile.mkdtemp()
        # Remember the configured path so that `tearDown` restores exactly what was there before.
        self._original_transformers_path = check_copies.TRANSFORMERS_PATH
        # The module-level constant is named `TRANSFORMERS_PATH`; assigning to any other attribute name
        # would silently leave the script looking at the real sources.
        check_copies.TRANSFORMERS_PATH = self.transformer_dir
        shutil.copy(
            os.path.join(git_repo_path, "src/transformers/modeling_bert.py"),
            os.path.join(self.transformer_dir, "modeling_bert.py"),
        )

    def tearDown(self):
        """Restore `check_copies.TRANSFORMERS_PATH` and delete the temporary directory."""
        check_copies.TRANSFORMERS_PATH = self._original_transformers_path
        shutil.rmtree(self.transformer_dir)

    def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
        """Write a small module made of `comment`, a class header and `class_code`, then check it.

        Args:
            comment: The `# Copied from ...` line placed right above the class.
            class_name: Name used in the generated `class ...(nn.Module):` header.
            class_code: Body of the class written to the file.
            overwrite_result: When `None`, the file is expected to be consistent as written. Otherwise the
                file is fixed with `overwrite=True` and its class body must then equal this value.
        """
        code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
        fname = os.path.join(self.transformer_dir, "new_code.py")
        with open(fname, "w") as f:
            f.write(code)

        if overwrite_result is None:
            self.assertTrue(check_copies.is_copy_consistent(fname))
            return

        expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
        check_copies.is_copy_consistent(fname, overwrite=True)
        with open(fname, "r") as f:
            # `assertEqual` is required here: `assertTrue(a, b)` would use `b` as the failure message
            # and pass for any non-empty file.
            self.assertEqual(f.read(), expected)

    def test_find_code_in_transformers(self):
        """The code extracted for `BertLMPredictionHead` must equal the reference snippet."""
        code = check_copies.find_code_in_transformers("modeling_bert.BertLMPredictionHead")
        self.assertEqual(code, REFERENCE_CODE)

    def test_is_copy_consistent(self):
        """Exercise `is_copy_consistent` on plain copies, renamed copies and the overwrite mode."""
        # Base copy consistency
        self.check_copy_consistency(
            "# Copied from transformers.modeling_bert.BertLMPredictionHead",
            "BertLMPredictionHead",
            REFERENCE_CODE + "\n",
        )

        # With no empty line at the end
        self.check_copy_consistency(
            "# Copied from transformers.modeling_bert.BertLMPredictionHead",
            "BertLMPredictionHead",
            REFERENCE_CODE,
        )

        # Copy consistency with rename
        self.check_copy_consistency(
            "# Copied from transformers.modeling_bert.BertLMPredictionHead with Bert->TestModel",
            "TestModelLMPredictionHead",
            re.sub("Bert", "TestModel", REFERENCE_CODE),
        )

        # Copy consistency with a really long name
        long_class_name = "TestModelWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReasonIReallyDontUnderstand"
        self.check_copy_consistency(
            f"# Copied from transformers.modeling_bert.BertLMPredictionHead with Bert->{long_class_name}",
            f"{long_class_name}LMPredictionHead",
            re.sub("Bert", long_class_name, REFERENCE_CODE),
        )

        # Copy consistency with overwrite
        self.check_copy_consistency(
            "# Copied from transformers.modeling_bert.BertLMPredictionHead with Bert->TestModel",
            "TestModelLMPredictionHead",
            REFERENCE_CODE,
            overwrite_result=re.sub("Bert", "TestModel", REFERENCE_CODE),
        )
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import os
import re
import tempfile
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
# Relative path of the folder that is scanned for `# Copied from` comments and in which the original
# objects are looked up. It is resolved against the current working directory.
TRANSFORMERS_PATH = "src/transformers"
def find_code_in_transformers(object_name):
    """Find and return the source code of the body of `object_name`.

    Args:
        object_name: Dotted path relative to `TRANSFORMERS_PATH`, made of a module path followed by the
            (possibly nested) class/function names, e.g. `modeling_bert.BertLMPredictionHead`.

    Returns:
        The lines following the `class`/`def` statement of the object up to (and excluding) the first
        line that is less indented, with trailing blank lines removed, joined as a single string.

    Raises:
        ValueError: If no prefix of `object_name` is a module file under `TRANSFORMERS_PATH`, or if one of
            the remaining names has no matching `class`/`def` statement in that module.
    """
    parts = object_name.split(".")
    i = 0

    # First let's find the module where our object lives.
    module = parts[i]
    while i < len(parts) and not os.path.isfile(os.path.join(TRANSFORMERS_PATH, f"{module}.py")):
        i += 1
        # Only extend the candidate path while there is a part left; otherwise fall through to the
        # explicit error below instead of failing with an IndexError.
        if i < len(parts):
            module = os.path.join(module, parts[i])
    if i >= len(parts):
        raise ValueError(
            f"`object_name` should begin with the name of a module of transformers but got {object_name}."
        )

    with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Now let's find the class / func in the code!
    indent = ""
    line_index = 0
    for name in parts[i + 1 :]:
        # The name must be followed by `(` or `:` so that `Foo` does not match `class FooBar`.
        definition = re.compile(rf"^{indent}(class|def)\s+{re.escape(name)}\s*(\(|:)")
        while line_index < len(lines) and definition.search(lines[line_index]) is None:
            line_index += 1
        if line_index >= len(lines):
            raise ValueError(f" {object_name} does not match any function or class in {module}.")
        # Nested definitions and the body are one level (four spaces) deeper than the statement found.
        indent += "    "
        line_index += 1

    # We found the beginning of the class / func, now let's find the end (when the indent diminishes).
    start_index = line_index
    while line_index < len(lines) and (lines[line_index].startswith(indent) or len(lines[line_index]) <= 1):
        line_index += 1
    # Clean up empty lines at the end (if any), without ever going back past the start of the body.
    while line_index > start_index and len(lines[line_index - 1]) <= 1:
        line_index -= 1

    code_lines = lines[start_index:line_index]
    return "".join(code_lines)
# Matches a `# Copied from transformers.<dotted.object.name> [rest]` comment line. The three groups are the
# leading whitespace of the comment, the dotted object name (without the `transformers.` prefix) and
# whatever text follows it on the line (possibly empty).
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
# Matches a `with <old>-><new>` clause; the two groups are the `<old>` and `<new>` tokens.
_re_replace_pattern = re.compile(r"with\s+(\S+)->(\S+)(?:\s|$)")
def blackify(code):
    """
    Run the `black` command line (line length 119, target version py35) over `code` and return the result.

    The snippet is written to a file in a temporary directory, formatted in place and read back. A snippet
    that starts with a space is first placed under a throwaway `class Bla:` header, which is removed again
    from the returned text. The exit status of the `black` command is not inspected.
    """
    needs_wrapper = code.startswith(" ")
    source = f"class Bla:\n{code}" if needs_wrapper else code

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_file = os.path.join(tmp_dir, "tmp.py")
        with open(tmp_file, "w") as handle:
            handle.write(source)
        os.system(f"black -q --line-length 119 --target-version py35 {tmp_file}")
        with open(tmp_file, "r") as handle:
            formatted = handle.read()

    if needs_wrapper:
        # Drop the artificial class header that was only there to make the snippet parseable.
        return formatted[len("class Bla:\n") :]
    return formatted
def is_copy_consistent(filename, overwrite=False):
    """
    Check if the code commented as a copy in `filename` matches the original.

    Every `# Copied from transformers.xxx` comment is resolved with `find_code_in_transformers`, the optional
    `with A->B` rename is applied to the original, and both versions are passed through `blackify` before
    being compared.

    Args:
        filename: Path of the Python file to check.
        overwrite: When `True`, copies that differ are replaced by the (renamed, formatted) original and the
            file is rewritten if at least one difference was found.

    Returns:
        `True` if no difference was found, `False` otherwise (also when the differences were fixed).
    """
    with open(filename, "r", encoding="utf-8") as f:
        lines = f.readlines()
    found_diff = False
    line_index = 0
    # Not a for loop because `lines` is going to change (if `overwrite=True`).
    while line_index < len(lines):
        search = _re_copy_warning.search(lines[line_index])
        if search is None:
            line_index += 1
            continue

        # There is some copied code here, let's retrieve the original.
        indent, object_name, replace_pattern = search.groups()
        theoretical_code = find_code_in_transformers(object_name)
        theoretical_indent = re.search(r"^(\s*)\S", theoretical_code).groups()[0]

        # When the comment is less indented than the body, it sits above the `class`/`def` statement, so the
        # copied body only starts two lines below the comment.
        start_index = line_index + 1 if indent == theoretical_indent else line_index + 2
        indent = theoretical_indent
        line_index = start_index

        # Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
        should_continue = True
        while line_index < len(lines) and should_continue:
            line_index += 1
            if line_index >= len(lines):
                break
            line = lines[line_index]
            should_continue = (len(line) <= 1 or line.startswith(indent)) and re.search(
                f"^{indent}# End copy", line
            ) is None
        # Clean up empty lines at the end (if any), without going back past the start of the copy.
        while line_index > start_index and len(lines[line_index - 1]) <= 1:
            line_index -= 1

        observed_code_lines = lines[start_index:line_index]
        observed_code = "".join(observed_code_lines)

        # Before comparing, use the `replace_pattern` on the original code.
        if len(replace_pattern) > 0:
            search_patterns = _re_replace_pattern.search(replace_pattern)
            if search_patterns is not None:
                obj1, obj2 = search_patterns.groups()
                # The names are plain identifiers, so substitute them literally rather than as a regex.
                theoretical_code = theoretical_code.replace(obj1, obj2)

        # Blackify each version before comparing them.
        observed_code = blackify(observed_code)
        theoretical_code = blackify(theoretical_code)

        # Test for a diff and act accordingly.
        if observed_code != theoretical_code:
            found_diff = True
            if overwrite:
                # The replacement is inserted as a single list element, hence the `+ 1` below.
                lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
                line_index = start_index + 1

    if overwrite and found_diff:
        # Warn the user a file has been modified.
        print(f"Detected changes, rewriting {filename}.")
        with open(filename, "w", encoding="utf-8") as f:
            f.writelines(lines)
    return not found_diff
def check_copies(overwrite: bool = False):
    """Check every Python file under `TRANSFORMERS_PATH` (recursively) for inconsistent copies.

    Args:
        overwrite: Forwarded to `is_copy_consistent`; when `True` the inconsistent copies are fixed in place
            and no error is raised.

    Raises:
        Exception: If `overwrite` is `False` and at least one file contains an inconsistent copy. The message
            lists the offending files.
    """
    all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
    diffs = [filename for filename in all_files if not is_copy_consistent(filename, overwrite)]
    if not overwrite and len(diffs) > 0:
        diff = "\n".join(diffs)
        raise Exception(
            "Found copy inconsistencies in the following files:\n"
            + diff
            + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
        )
if __name__ == "__main__":
    # Command-line entry point: the only option is `--fix_and_overwrite`, which switches from reporting
    # inconsistent copies to rewriting them.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    check_copies(cli.parse_args().fix_and_overwrite)
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib import importlib
import inspect import inspect
import os import os
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment