Unverified Commit 156d30da authored by Jacky Lee's avatar Jacky Lee Committed by GitHub
Browse files

Add warning message for `run_qa.py` (#29867)

* improve: error message for best model metric

* update: raise warning instead of error
parent 6fd93fe9
......@@ -627,6 +627,14 @@ def main():
references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references)
if data_args.version_2_with_negative:
accepted_best_metrics = ("exact", "f1", "HasAns_exact", "HasAns_f1")
else:
accepted_best_metrics = ("exact_match", "f1")
if training_args.load_best_model_at_end and training_args.metric_for_best_model not in accepted_best_metrics:
warnings.warn(f"--metric_for_best_model should be set to one of {accepted_best_metrics}")
metric = evaluate.load(
"squad_v2" if data_args.version_2_with_negative else "squad", cache_dir=model_args.cache_dir
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment