Commit b9b1ac06 authored by Björn Bebensee's avatar Björn Bebensee
Browse files

Match prompt with existing works; fix article regex and whitespace

parent 71d3655b
...@@ -12,7 +12,7 @@ answered using the contents of English Wikipedia. ...@@ -12,7 +12,7 @@ answered using the contents of English Wikipedia.
Homepage: https://github.com/google-research-datasets/natural-questions/tree/master/nq_open Homepage: https://github.com/google-research-datasets/natural-questions/tree/master/nq_open
""" """
import re import regex
import string import string
from lm_eval.base import Task, rf from lm_eval.base import Task, rf
from lm_eval.metrics import mean from lm_eval.metrics import mean
...@@ -60,7 +60,7 @@ class NQOpen(Task): ...@@ -60,7 +60,7 @@ class NQOpen(Task):
raise NotImplementedError() raise NotImplementedError()
def doc_to_text(self, doc): def doc_to_text(self, doc):
return f"Question: {doc['question']}\nAnswer:" return f"Q: {doc['question']}\nA:"
def should_decontaminate(self): def should_decontaminate(self):
return True return True
...@@ -84,6 +84,18 @@ class NQOpen(Task): ...@@ -84,6 +84,18 @@ class NQOpen(Task):
continuation = rf.greedy_until(ctx, {"until": ["\n", ".", ","]}) continuation = rf.greedy_until(ctx, {"until": ["\n", ".", ","]})
return continuation return continuation
def _normalize_answer(self, text):
# Lowercase and remove punctuation, strip whitespace
text = text.strip().lower().translate(str.maketrans('', '', string.punctuation))
# Remove articles, resulting in duplicate whitespace
text = regex.sub(r'\b(a|an|the)\b', ' ', text)
# Remove duplicate whitespace
text = " ".join(text.split())
return text
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
...@@ -94,16 +106,9 @@ class NQOpen(Task): ...@@ -94,16 +106,9 @@ class NQOpen(Task):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation)) continuation = self._normalize_answer(results[0])
answers = [answer.lower().translate(str.maketrans('', '', string.punctuation)) for answer in doc["answer"]] answers = [self._normalize_answer(answer) for answer in doc["answer"]]
# remove duplicate whitespace
continuation = " ".join(continuation.split())
# remove articles
continuation = re.sub('(\s+)(a|an|the)(\s+)', ' ', continuation)
answers = [re.sub('(\s+)(a|an|the)(\s+)', ' ', cand) for cand in answers]
return { return {
"em": float(continuation in answers) "em": float(continuation in answers)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment