Unverified Commit 1a8f6111 authored by Sebastian's avatar Sebastian Committed by GitHub
Browse files

fix: Update run_qa.py to work with deepset/germanquad (#23225)

Call str on id to make sure any ints are converted into the expected format for squad datasets
parent 51ae5665
...@@ -590,12 +590,12 @@ def main(): ...@@ -590,12 +590,12 @@ def main():
# Format the result to the format the metric expects. # Format the result to the format the metric expects.
if data_args.version_2_with_negative: if data_args.version_2_with_negative:
formatted_predictions = [ formatted_predictions = [
{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() {"id": str(k), "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
] ]
else: else:
formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] formatted_predictions = [{"id": str(k), "prediction_text": v} for k, v in predictions.items()]
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] references = [{"id": str(ex["id"]), "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references) return EvalPrediction(predictions=formatted_predictions, label_ids=references)
metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad") metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment