Commit c6a35696 authored by Stephen Hogg

Include extraction in process_results; fixes per test results

parent c2f12474
@@ -22,12 +22,86 @@ https://arxiv.org/abs/2105.03011
   bibsource = {dblp computer science bibliography, https://dblp.org}
 }
 """
+from collections import Counter
 from math import exp
+import re
+import string
 from lm_eval.base import rf
-from lm_eval.metrics import f1_score
+from lm_eval.metrics import f1_score, mean
 from .common import HFTask
+
+
+def normalize_answer(s):
+    """
+    Taken from the official evaluation script for v1.1 of the SQuAD dataset.
+    Lower text and remove punctuation, articles and extra whitespace.
+    """
+
+    def remove_articles(text):
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+    def white_space_fix(text):
+        return " ".join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def categorise_answer(answer_blob):
+    if answer_blob["unanswerable"]:
+        answer = "unanswerable"
+        answer_type = "unanswerable"
+        return answer, answer_type
+    elif answer_blob["yes_no"]:
+        answer = "Yes"
+        answer_type = "bool"
+        return answer, answer_type
+    elif answer_blob["free_form_answer"]:
+        answer = answer_blob["free_form_answer"]
+        answer_type = "free form answer"
+        return answer, answer_type
+    elif answer_blob["extractive_spans"]:
+        answer = answer_blob["extractive_spans"]
+        answer_type = "extractive spans"
+        return answer, answer_type
+    elif answer_blob["yes_no"] is False:
+        answer = "No"
+        answer_type = "bool"
+        return answer, answer_type
+
+
+def token_f1_score(prediction, ground_truth):
+    """
+    Taken from the official evaluation script for v1.1 of the SQuAD dataset.
+    """
+    prediction_tokens = normalize_answer(prediction).split()
+    ground_truth_tokens = normalize_answer(ground_truth).split()
+    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return 0
+    precision = 1.0 * num_same / len(prediction_tokens)
+    recall = 1.0 * num_same / len(ground_truth_tokens)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1
+
+
+def paragraph_f1_score(prediction, ground_truth):
+    num_same = len(set(ground_truth).intersection(set(prediction)))
+    if num_same == 0:
+        return 0.0
+    precision = num_same / len(prediction)
+    recall = num_same / len(ground_truth)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1
 
 
 class QASPER(HFTask):
     VERSION = 0
     DATASET_PATH = "qasper"
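A quick usage sketch for the helpers added above. The inputs are invented examples rather than QASPER data, and the snippet assumes it runs in the same module so normalize_answer, token_f1_score and paragraph_f1_score are in scope:

# Illustrative inputs only; run alongside the helpers defined in this commit.
print(normalize_answer("The F1-Score, roughly!"))      # "f1score roughly"
print(token_f1_score("a BiLSTM baseline", "BiLSTM"))   # ~0.67 token-overlap F1
print(paragraph_f1_score(["p1", "p3"], ["p1", "p2"]))  # 0.5 paragraph-set F1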
@@ -50,7 +124,7 @@ class QASPER(HFTask):
 
     def doc_to_target(self, doc):
        # this method is invoked by tests only
-        return " " + doc["answer_str"]
+        return " " + doc["answer"]
 
     def training_docs(self):
        for doc in self.data["train"]:
@@ -67,33 +141,18 @@ class QASPER(HFTask):
         https://github.com/allenai/qasper-led-baseline/blob/main/scripts/evaluator.py
         """
         obs_list = []
-        for qa in doc["qas"]:
-            for question, answer_list in zip(qa["question"], qa["answers"]):
-                for answer in answer_list:
-                    if answer["unanswerable"]:
-                        answer_str = "unanswerable"
-                        answer_type = "unanswerable"
-                    elif answer["yes_no"]:
-                        answer_str = "Yes"
-                        answer_type = "bool"
-                    elif answer["yes_no"] is not None:
-                        answer_str = "No"
-                        answer_type = "bool"
-                    elif answer["free_form_answer"]:
-                        answer_str = answer["free_form_answer"]
-                        answer_type = "free form answer"
-                    elif answer["extractive_spans"]:
-                        answer_str = ", ".join(answer["extractive_spans"])
-                        answer_type = "extractive spans"
-                    obs_list.append[
-                        {
-                            "title": doc["title"],
-                            "abstract": doc["abstract"],
-                            "question": question,
-                            "answer_str": answer_str,
-                            "answer_type": answer_type,
-                        }
-                    ]
+        for question, answer_list in zip(doc["qas"]["question"], doc["qas"]["answers"]):
+            for answer_blob in answer_list["answer"]:
+                answer, answer_type = categorise_answer(answer_blob)
+                obs_list.append(
+                    {
+                        "title": doc["title"],
+                        "abstract": doc["abstract"],
+                        "question": question,
+                        "answer": answer,
+                        "answer_type": answer_type,
+                    }
+                )
         return obs_list
 
     def process_results(self, doc, results):
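As a hedged illustration of what the refactored document processing now stores per question, here is a toy answer blob run through categorise_answer. The field names follow the diff above; the values are made up, and the snippet assumes categorise_answer is in scope from this module:

toy_blob = {
    "unanswerable": False,
    "yes_no": None,
    "free_form_answer": "",
    "extractive_spans": ["the LED baseline", "a BERT-to-BERT model"],
}
answer, answer_type = categorise_answer(toy_blob)
# Falls through to the extractive branch: the raw span list is kept,
# unlike the old _process_doc, which joined the spans with ", ".
assert answer_type == "extractive spans"
assert answer == ["the LED baseline", "a BERT-to-BERT model"]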
@@ -114,16 +173,15 @@ class QASPER(HFTask):
         # Handle completions
         if doc["answer_type"] == "free form answer":
-            res_dict["f1_ab"] = None
+            res_dict["f1_ab"] = token_f1_score(res["answer"], doc["answer"])
 
+        # Handle extraction
+        if doc["answer_type"] == "extractive spans":
+            res_dict["f1_ex"] = paragraph_f1_score(res["answer"], doc["answer"])
+
         return res_dict
 
     def aggregation(self):
-        return {
-            "f1_un": f1_score,
-            "f1_yn": f1_score,
-            "f1_ab": f1_score,
-            "f1_ex": f1_score,
-        }
+        return {"f1_un": f1_score, "f1_yn": f1_score, "f1_ab": mean, "f1_ex": mean}
 
     def construct_requests(self, doc, ctx):
         """Uses RequestFactory to construct Requests and returns an iterable of
@@ -138,9 +196,18 @@ class QASPER(HFTask):
         """
         unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
         if doc["answer_type"] in ("free form answer", "extractive spans"):
-            res = rf.greedy_until(ctx, ["\n"])
+            return rf.greedy_until(ctx, ["\n"]), unanswerable
         elif doc["answer_type"] in ("bool"):
             ll_yes, _ = rf.loglikelihood(ctx, " yes")
             ll_no, _ = rf.loglikelihood(ctx, " no")
-            res = (ll_yes, ll_no)
-        return res, unanswerable
+            return ll_yes, ll_no, unanswerable
+        else:
+            return unanswerable
+
+    def higher_is_better(self):
+        """
+        :returns: {str: bool}
+            A dictionary where keys are the names of submetrics and values are
+            whether a higher value of the submetric is better
+        """
+        return {"f1_un": True, "f1_yn": True, "f1_ab": True, "f1_ex": True}