Commit c6a35696 authored by Stephen Hogg

Include extraction in process_results; fixes per-test results

parent c2f12474
@@ -22,12 +22,86 @@ https://arxiv.org/abs/2105.03011
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
from collections import Counter
from math import exp
import re
import string
from lm_eval.base import rf
from lm_eval.metrics import f1_score
from lm_eval.metrics import f1_score, mean
from .common import HFTask

def normalize_answer(s):
    """
    Taken from the official evaluation script for v1.1 of the SQuAD dataset.
    Lower text and remove punctuation, articles and extra whitespace.
    """

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
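
# Illustrative behaviour, worked by hand from the rules above:
#   normalize_answer("The Quick, Brown Fox!")  -> "quick brown fox"
#   normalize_answer("an   extra   space")     -> "extra space"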

def categorise_answer(answer_blob):
    if answer_blob["unanswerable"]:
        answer = "unanswerable"
        answer_type = "unanswerable"
        return answer, answer_type
    elif answer_blob["yes_no"]:
        answer = "Yes"
        answer_type = "bool"
        return answer, answer_type
    elif answer_blob["free_form_answer"]:
        answer = answer_blob["free_form_answer"]
        answer_type = "free form answer"
        return answer, answer_type
    elif answer_blob["extractive_spans"]:
        answer = answer_blob["extractive_spans"]
        answer_type = "extractive spans"
        return answer, answer_type
    elif answer_blob["yes_no"] is False:
        answer = "No"
        answer_type = "bool"
        return answer, answer_type
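
# Illustrative mapping, worked by hand from the branches above (only the keys
# actually reached by each branch need to be present):
#   categorise_answer({"unanswerable": True})  -> ("unanswerable", "unanswerable")
#   categorise_answer({"unanswerable": False, "yes_no": None,
#                      "free_form_answer": "", "extractive_spans": ["a span"]})
#       -> (["a span"], "extractive spans")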

def token_f1_score(prediction, ground_truth):
    """
    Taken from the official evaluation script for v1.1 of the SQuAD dataset.
    """
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
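
# Illustrative value, worked by hand: "the cat sat" vs. "cat sat down" share the
# tokens {"cat", "sat"} after normalisation, so precision = 2/2, recall = 2/3 and
#   token_f1_score("the cat sat", "cat sat down")  -> 0.8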

def paragraph_f1_score(prediction, ground_truth):
    num_same = len(set(ground_truth).intersection(set(prediction)))
    if num_same == 0:
        return 0.0
    precision = num_same / len(prediction)
    recall = num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
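
# Illustrative value, worked by hand: treating each list item as one span,
# ["span a", "span b"] vs. ["span b", "span c"] share one item, so
# precision = recall = 1/2 and
#   paragraph_f1_score(["span a", "span b"], ["span b", "span c"])  -> 0.5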

class QASPER(HFTask):
    VERSION = 0
    DATASET_PATH = "qasper"
@@ -50,7 +124,7 @@ class QASPER(HFTask):
    def doc_to_target(self, doc):
        # this method is invoked by tests only
        return " " + doc["answer_str"]
        return " " + doc["answer"]

    def training_docs(self):
        for doc in self.data["train"]:
@@ -67,33 +141,18 @@ class QASPER(HFTask):
        https://github.com/allenai/qasper-led-baseline/blob/main/scripts/evaluator.py
        """
        obs_list = []
        for qa in doc["qas"]:
            for question, answer_list in zip(qa["question"], qa["answers"]):
                for answer in answer_list:
                    if answer["unanswerable"]:
                        answer_str = "unanswerable"
                        answer_type = "unanswerable"
                    elif answer["yes_no"]:
                        answer_str = "Yes"
                        answer_type = "bool"
                    elif answer["yes_no"] is not None:
                        answer_str = "No"
                        answer_type = "bool"
                    elif answer["free_form_answer"]:
                        answer_str = answer["free_form_answer"]
                        answer_type = "free form answer"
                    elif answer["extractive_spans"]:
                        answer_str = ", ".join(answer["extractive_spans"])
                        answer_type = "extractive spans"
                    obs_list.append[
                        {
                            "title": doc["title"],
                            "abstract": doc["abstract"],
                            "question": question,
                            "answer_str": answer_str,
                            "answer_type": answer_type,
                        }
                    ]
        for question, answer_list in zip(doc["qas"]["question"], doc["qas"]["answers"]):
            for answer_blob in answer_list["answer"]:
                answer, answer_type = categorise_answer(answer_blob)
                obs_list.append(
                    {
                        "title": doc["title"],
                        "abstract": doc["abstract"],
                        "question": question,
                        "answer": answer,
                        "answer_type": answer_type,
                    }
                )
        return obs_list

    def process_results(self, doc, results):
@@ -114,16 +173,15 @@ class QASPER(HFTask):
        # Handle completions
        if doc["answer_type"] == "free form answer":
            res_dict["f1_ab"] = None
            res_dict["f1_ab"] = token_f1_score(res["answer"], doc["answer"])
        # Handle extraction
        if doc["answer_type"] == "extractive spans":
            res_dict["f1_ex"] = paragraph_f1_score(res["answer"], doc["answer"])
        return res_dict

    def aggregation(self):
        return {
            "f1_un": f1_score,
            "f1_yn": f1_score,
            "f1_ab": f1_score,
            "f1_ex": f1_score,
        }
        return {"f1_un": f1_score, "f1_yn": f1_score, "f1_ab": mean, "f1_ex": mean}

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
@@ -138,9 +196,18 @@ class QASPER(HFTask):
"""
unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
if doc["answer_type"] in ("free form answer", "extractive spans"):
res = rf.greedy_until(ctx, ["\n"])
return rf.greedy_until(ctx, ["\n"]), unanswerable
elif doc["answer_type"] in ("bool"):
ll_yes, _ = rf.loglikelihood(ctx, " yes")
ll_no, _ = rf.loglikelihood(ctx, " no")
res = (ll_yes, ll_no)
return res, unanswerable
return ll_yes, ll_no, unanswerable
else:
return unanswerable
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
return {"f1_un": True, "f1_yn": True, "f1_ab": True, "f1_ex": True}