Unverified Commit 4ba24874 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Get a better error when check_copies fails (#7457)

* Get a better error when check_copies fails

* Fix tests
parent bef01751
...@@ -55,7 +55,7 @@ class CopyCheckTester(unittest.TestCase): ...@@ -55,7 +55,7 @@ class CopyCheckTester(unittest.TestCase):
with open(fname, "w") as f: with open(fname, "w") as f:
f.write(code) f.write(code)
if overwrite_result is None: if overwrite_result is None:
self.assertTrue(check_copies.is_copy_consistent(fname)) self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0)
else: else:
check_copies.is_copy_consistent(f.name, overwrite=True) check_copies.is_copy_consistent(f.name, overwrite=True)
with open(fname, "r") as f: with open(fname, "r") as f:
......
...@@ -96,7 +96,7 @@ def is_copy_consistent(filename, overwrite=False): ...@@ -96,7 +96,7 @@ def is_copy_consistent(filename, overwrite=False):
""" """
with open(filename, "r", encoding="utf-8") as f: with open(filename, "r", encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()
found_diff = False diffs = []
line_index = 0 line_index = 0
# Not a foor loop cause `lines` is going to change (if `overwrite=True`). # Not a foor loop cause `lines` is going to change (if `overwrite=True`).
while line_index < len(lines): while line_index < len(lines):
...@@ -140,30 +140,29 @@ def is_copy_consistent(filename, overwrite=False): ...@@ -140,30 +140,29 @@ def is_copy_consistent(filename, overwrite=False):
# Test for a diff and act accordingly. # Test for a diff and act accordingly.
if observed_code != theoretical_code: if observed_code != theoretical_code:
found_diff = True diffs.append([object_name, start_index])
if overwrite: if overwrite:
lines = lines[:start_index] + [theoretical_code] + lines[line_index:] lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
line_index = start_index + 1 line_index = start_index + 1
if overwrite and found_diff: if overwrite and len(diffs) > 0:
# Warn the user a file has been modified. # Warn the user a file has been modified.
print(f"Detected changes, rewriting {filename}.") print(f"Detected changes, rewriting {filename}.")
with open(filename, "w", encoding="utf-8") as f: with open(filename, "w", encoding="utf-8") as f:
f.writelines(lines) f.writelines(lines)
return not found_diff return diffs
def check_copies(overwrite: bool = False): def check_copies(overwrite: bool = False):
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True) all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
diffs = [] diffs = []
for filename in all_files: for filename in all_files:
consistent = is_copy_consistent(filename, overwrite) new_diffs = is_copy_consistent(filename, overwrite)
if not consistent: diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
diffs.append(filename)
if not overwrite and len(diffs) > 0: if not overwrite and len(diffs) > 0:
diff = "\n".join(diffs) diff = "\n".join(diffs)
raise Exception( raise Exception(
"Found copy inconsistencies in the following files:\n" "Found the follwing copy inconsistencies:\n"
+ diff + diff
+ "\nRun `make fix-copies` or `python utils/check_copies --fix_and_overwrite` to fix them." + "\nRun `make fix-copies` or `python utils/check_copies --fix_and_overwrite` to fix them."
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment