Commit fc7cd630 authored by Jon Tow

Implement `DROP` evaluation

parent f3bf1c07
...@@ -29,6 +29,7 @@ from . import qa4mre
from . import translation
from . import headqa
from . import mathqa
from . import drop
########################################
# Translation tasks
...@@ -83,6 +84,7 @@ TASK_REGISTRY = {
    # Order by benchmark/genre?
    "coqa": coqa.CoQA,
    "drop": drop.DROP,
    "lambada": lambada.LAMBADA,
    "piqa": piqa.PiQA,
...
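As a quick sanity check on the wiring above (an editorial sketch, not part of the commit), the new task should be reachable through the registry once the import and `TASK_REGISTRY` entry land:

# Sketch: assumes lm_eval.tasks exposes TASK_REGISTRY as shown in the hunk above.
from lm_eval import tasks

drop_cls = tasks.TASK_REGISTRY["drop"]   # -> the DROP class defined below
print(drop_cls.__name__)                 # expected: "DROP"
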
import json
import numpy as np
import re
import transformers.data.metrics.squad_metrics as squad_metrics
from best_download import download_file
from scipy.optimize import linear_sum_assignment
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
from pathlib import Path
from zipfile import ZipFile

class DROP(Task):
    DATAFOLDER = Path("data/drop")
    URL = "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip"

    def download(self):
        if self.DATAFOLDER.exists():
            return
        Path.mkdir(self.DATAFOLDER)
        download_file(self.URL, to=str(self.DATAFOLDER / "drop_dataset.zip"))
        with ZipFile(self.DATAFOLDER / "drop_dataset.zip", "r") as zip:
            zip.extractall(self.DATAFOLDER)

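    # Note on the extracted layout (an assumption inferred from the paths used in
    # training_docs/validation_docs below): the archive should unpack to
    # data/drop/drop_dataset/drop_dataset_train.json and
    # data/drop/drop_dataset/drop_dataset_dev.json.
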
    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def fewshot_description(self):
        # TODO: figure out description
        return ""

    def _load_docs(self, docs):
        for doc in docs:
            for qa in doc["qa_pairs"]:
                yield {
                    "passage": doc["passage"],
                    "question": qa["question"],
                    "answers": self.get_answers(qa["answer"]),
                }

    @classmethod
    def get_answers(cls, answers):
        # NOTE: We wrap every non-`list` answer into a list for uniformity.
        if answers["number"] != "":
            return [answers["number"]]
        if answers["spans"] != []:
            return answers["spans"]
        return [" ".join([answers["date"]["day"],
                          answers["date"]["month"],
                          answers["date"]["year"]]).strip()]

    def training_docs(self):
        docs = json.load(open(self.DATAFOLDER / "drop_dataset" / "drop_dataset_train.json"))
        return self._load_docs([docs[k] for k in docs.keys()])

    def validation_docs(self):
        docs = json.load(open(self.DATAFOLDER / "drop_dataset" / "drop_dataset_dev.json"))
        return self._load_docs([docs[k] for k in docs.keys()])

    def test_docs(self):
        pass

    def doc_to_text(self, doc):
        return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"

    def doc_to_target(self, doc):
        return " " + ", ".join(doc["answers"])

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        conts = []
        for _ in doc["answers"]:
            conts.append(rf.greedy_until(ctx, ["\n", "."]))
        return conts

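    # Sketch of the intended flow: construct_requests above issues one greedy_until
    # request per gold answer (e.g. answers == ["Pacific Ocean", "Pacific"] yields
    # two identical requests), so `results` in process_results is a list with one
    # generated string per request, each cut off at "\n" or ".".
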
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        gold, pred = doc["answers"], results
        exact_match = self._exact_match(gold, pred)
        f1_score = self._f1_score(gold, pred)
        return {"em": exact_match, "f1": f1_score}

    def _exact_match(self, golds, preds):
        """Returns the exact match of normalized gold answers and predictions."""
        normalized_golds = set([self._normalize(gold) for gold in golds])
        normalized_preds = set([self._normalize(pred) for pred in preds])
        return int(normalized_golds == normalized_preds)

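    # Illustrative: _exact_match(["Bob Ross"], ["bob ross"]) == 1 after
    # normalization, while _exact_match(["Bob Ross"], ["Bob"]) == 0 because the
    # normalized answer sets differ.
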
    def _f1_score(self, golds, preds):
        """Returns the average F1-score over normalized `gold` and `pred`
        answer lists.
        """
        gold_bags = self._answer_to_bags(golds)
        pred_bags = self._answer_to_bags(preds)
        f1_per_bag = self._align_bags(gold_bags, pred_bags)
        return np.mean(f1_per_bag)

    def _answer_to_bags(self, answers):
        return [set(self._normalize(answer).split()) for answer in answers]

    def _align_bags(self, gold_bags, pred_bags):
        """Returns the max metric value over all the answers."""
        scores = np.zeros([len(gold_bags), len(pred_bags)])
        for gold_index, gold_bag in enumerate(gold_bags):
            for pred_index, pred_bag in enumerate(pred_bags):
                if self._is_number_match(gold_bag, pred_bag):
                    scores[gold_index, pred_index] = self._bag_f1(pred_bag, gold_bag)
        row_ind, col_ind = linear_sum_assignment(-scores)
        max_scores = np.zeros([max(len(gold_bags), len(pred_bags))])
        for row, column in zip(row_ind, col_ind):
            max_scores[row] = max(max_scores[row], scores[row, column])
        return max_scores

    def _bag_f1(self, gold_bag, pred_bag):
        intersection = len(gold_bag.intersection(pred_bag))
        if intersection == 0:
            return 0.0
        precision = intersection / float(len(pred_bag)) if pred_bag else 1.0
        recall = intersection / float(len(gold_bag)) if gold_bag else 1.0
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    def _is_number_match(self, gold_bag, pred_bag):
        gold_numbers = set(filter(lambda s: s.isnumeric(), list(gold_bag)))
        pred_numbers = set(filter(lambda s: s.isnumeric(), list(pred_bag)))
        return (not gold_numbers) or gold_numbers.intersection(pred_numbers)

    def _normalize(self, answer):
        def tokenize(text):
            return re.split(" |-", text)
        tokens = [squad_metrics.normalize_answer(token) for token in tokenize(answer)]
        normalized = " ".join(tokens).strip()
        return normalized

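    # Worked example for the helpers above (illustrative; mirrors test_f1_score below):
    #   _f1_score(["25 to 44"], ["25 to 44 or 45 to 64"])
    #     gold bag {"25", "to", "44"}; pred bag {"25", "to", "44", "or", "45", "64"}
    #     the numeric tokens overlap, so bag F1 applies:
    #     F1 = 2 * (3/6) * (3/3) / ((3/6) + (3/3)) ≈ 0.67
    #   _f1_score(["300", "1992"], ["300", "1992"]) -> 1.0
    #     (bags are matched one-to-one by linear_sum_assignment; bags whose
    #      numbers do not overlap score 0)
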
    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {"em": mean, "f1": mean}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {"em": True, "f1": True}

# Temporary sanity-checks
def main():
    drop = DROP()

    def test_bags():
        multiple_answers = ["Pacific Ocean", "Pacific"]
        ma_bags = drop._answer_to_bags(multiple_answers)
        print(f"Multiple Choice Answer Bags: {multiple_answers} => {ma_bags}")
        assert len(ma_bags) == 2
        number_answer = ["1974"]
        number_bags = drop._answer_to_bags(number_answer)
        print(f"Number Bags: {number_answer} => {number_bags}")
        print()
    test_bags()

    def test_is_number_match():
        gold = ["10 29 1999"]
        pred = ["4 29 1990"]
        gb = drop._answer_to_bags(gold)
        pb = drop._answer_to_bags(pred)
        print(gb)
        print(pb)
        for g in gb:
            for p in pb:
                match = drop._is_number_match(g, p)
                print(match)
        print()
    # test_is_number_match()

    def test_exact_match():
        gold = ["Bob Ross"]
        pred = ["Bob Ross"]
        em = drop._exact_match(gold, pred)
        print(em)
    # test_exact_match()

    def test_f1_score():
        gold = ["25 to 44"]
        pred = ["25 to 44 or 45 to 64"]
        f1 = drop._f1_score(gold, pred)
        print(gold)
        print(pred)
        print(f1)
        gold = ["300", "1992"]
        pred = ["300", "1992"]
        f1 = drop._f1_score(gold, pred)
        print(f1)
    # test_f1_score()


if __name__ == "__main__":
    main()