Unverified Commit 57461ac0 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add possibility to maintain full copies of files (#12312)

parent 9490d668
...@@ -38,6 +38,7 @@ def postprocess_qa_predictions( ...@@ -38,6 +38,7 @@ def postprocess_qa_predictions(
null_score_diff_threshold: float = 0.0, null_score_diff_threshold: float = 0.0,
output_dir: Optional[str] = None, output_dir: Optional[str] = None,
prefix: Optional[str] = None, prefix: Optional[str] = None,
is_world_process_zero: bool = True,
): ):
""" """
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
...@@ -90,6 +91,7 @@ def postprocess_qa_predictions( ...@@ -90,6 +91,7 @@ def postprocess_qa_predictions(
scores_diff_json = collections.OrderedDict() scores_diff_json = collections.OrderedDict()
# Logging. # Logging.
logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
# Let's loop over all the examples! # Let's loop over all the examples!
......
...@@ -27,6 +27,9 @@ TRANSFORMERS_PATH = "src/transformers" ...@@ -27,6 +27,9 @@ TRANSFORMERS_PATH = "src/transformers"
PATH_TO_DOCS = "docs/source" PATH_TO_DOCS = "docs/source"
REPO_PATH = "." REPO_PATH = "."
# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to data with)
FULL_COPIES = {"examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py"}
def _should_continue(line, indent): def _should_continue(line, indent):
return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\):\s*$", line) is not None return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\):\s*$", line) is not None
...@@ -192,6 +195,30 @@ def check_copies(overwrite: bool = False): ...@@ -192,6 +195,30 @@ def check_copies(overwrite: bool = False):
check_model_list_copy(overwrite=overwrite) check_model_list_copy(overwrite=overwrite)
def check_full_copies(overwrite: bool = False):
diffs = []
for target, source in FULL_COPIES.items():
with open(source, "r", encoding="utf-8") as f:
source_code = f.read()
with open(target, "r", encoding="utf-8") as f:
target_code = f.read()
if source_code != target_code:
if overwrite:
with open(target, "w", encoding="utf-8") as f:
print(f"Replacing the content of {target} by the one of {source}.")
f.write(source_code)
else:
diffs.append(f"- {target}: copy does not match {source}.")
if not overwrite and len(diffs) > 0:
diff = "\n".join(diffs)
raise Exception(
"Found the following copy inconsistencies:\n"
+ diff
+ "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
)
def get_model_list(): def get_model_list():
"""Extracts the model list from the README.""" """Extracts the model list from the README."""
# If the introduction or the conclusion of the list change, the prompts may need to be updated. # If the introduction or the conclusion of the list change, the prompts may need to be updated.
...@@ -324,3 +351,4 @@ if __name__ == "__main__": ...@@ -324,3 +351,4 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
check_copies(args.fix_and_overwrite) check_copies(args.fix_and_overwrite)
check_full_copies(args.fix_and_overwrite)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment