Commit 84774770 authored by Leo Gao

Merge branch 'fazz/refactor-task-coqa' of github.com:EleutherAI/lm_evaluation_harness into fazz/refactor-task-coqa
parents 758b9e3c 6a534600
lm_eval/tasks/__init__.py
@@ -3,6 +3,7 @@ from pprint import pprint
from . import superglue
from . import glue
from . import arc
from . import coqa
from . import race
from . import webqs
from . import anli
@@ -49,7 +50,7 @@ TASK_REGISTRY = {
"wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre?
"coqa": coqa.CoQA,
"lambada": lambada.LAMBADA,
"piqa": piqa.PiQA,
lm_eval/tasks/coqa.py
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import os
import json
import random
from lm_eval.base import Task, rf, mean
from ..utils import sh
from itertools import zip_longest
import transformers.data.metrics.squad_metrics as squad_metrics
import collections
import datasets
import numpy as np
from .common import HFTask
from tqdm import tqdm
import string
import re
class CoQA(Task):
    def __init__(self):
        self.download()

    def download(self):
        coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json'
        coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json'
        # Download only if the files aren't already there.
        sh("mkdir -p data/coqa")
        if not os.path.exists(coqa_train_filepath):
            sh("wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O " + coqa_train_filepath)
        if not os.path.exists(coqa_dev_filepath):
            sh("wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O " + coqa_dev_filepath)

    def has_training_docs(self):
        return True
@@ -30,22 +37,78 @@ class CoQA(Task):
    def training_docs(self):
        return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']

    def validation_docs(self):
        return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']

    def test_docs(self):
        pass

    def fewshot_description(self):
        return "Given a passage and a conversation so far, answer the next question in the conversation."
    def doc_to_text(self, doc):
        # Given a passage p, the conversation history {q1, a1, ..., q_{i-1}, a_{i-1}}
        # and a question q_i, the task is to predict the answer a_i.
        doc_text = doc["story"] + '\n\n'
        for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer a_i
            question = f"Q: {q['input_text']}" + '\n\n'
            answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:"
            doc_text += question + answer
        return doc_text
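
For reference, a sketch of the prompt this produces on a made-up two-turn document (not from the dataset); the last turn's answer is omitted, leaving a trailing `A:` for the model to complete. The module path in the import is assumed from the file layout.

from lm_eval.tasks.coqa import CoQA  # assumed module path

doc = {
    "story": "Janet has a cat named Milo.",
    "questions": [{"input_text": "Who is Milo?"}, {"input_text": "Who owns him?"}],
    "answers": [{"input_text": "a cat"}, {"input_text": "Janet"}],
}
print(CoQA().doc_to_text(doc))  # note: instantiating CoQA triggers the dataset download
# Janet has a cat named Milo.
#
# Q: Who is Milo?
#
# A: a cat
#
# Q: Who owns him?
#
# A: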
    @classmethod
    def get_answers(cls, doc, turn_id):
        # Returns the unique answers for a turn, including valid alternatives
        # (some questions in CoQA have multiple valid answers).
        answers = []
        answer_for_turn = doc["answers"][turn_id - 1]["input_text"]
        answers.append(answer_for_turn)
        additional_answers = doc.get("additional_answers")
        if additional_answers:
            for key in additional_answers:
                additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
                if additional_answer_for_turn.lower() not in map(str.lower, answers):
                    answers.append(additional_answer_for_turn)
        return answers
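
A small illustration of the de-duplication, with a fabricated document: alternatives from `additional_answers` are kept only when they differ case-insensitively from answers already collected.

# Fabricated single-turn document, for illustration only.
doc = {
    "answers": [{"input_text": "Janet"}],
    "additional_answers": {
        "0": [{"input_text": "janet"}],        # case-insensitive duplicate: skipped
        "1": [{"input_text": "Janet Smith"}],  # distinct alternative: kept
    },
}
CoQA.get_answers(doc, turn_id=1)  # -> ["Janet", "Janet Smith"]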
    @classmethod
    def get_answer_choice(cls, raw_text):
        # Maps an answer to one of the CoQA answer categories.
        # ~1/5 of the CoQA answers are yes/no;
        # ~2/3 of the CoQA answers are span-based
        # (they overlap with the passage, ignoring punctuation and case mismatches).
        if raw_text == "unknown":
            return '0'
        if squad_metrics.normalize_answer(raw_text) == "yes":
            return '1'
        if squad_metrics.normalize_answer(raw_text) == "no":
            return '2'
        return '3'  # Not a yes/no question
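
For illustration, how the category mapping behaves on a few made-up answers (`normalize_answer` lowercases and strips articles and punctuation):

CoQA.get_answer_choice("unknown")  # '0'
CoQA.get_answer_choice("Yes.")     # '1' (normalizes to "yes")
CoQA.get_answer_choice("No")       # '2'
CoQA.get_answer_choice("a cat")    # '3' (span-based answer)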
    @staticmethod
    def compute_scores(gold_list, pred):
        # Tests for exact match on the normalized answer (compute_exact)
        # and for token overlap (compute_f1).
        f1_sum = 0.0
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                gold_answers = gold_list[0:i] + gold_list[i + 1:]
                # The prediction is compared against the n - 1 remaining golds
                # and the maximum score is taken.
                em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
        else:
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
        return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
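
A worked sketch of the leave-one-out averaging, with made-up strings: with two or more golds, the prediction is scored once per gold against the remaining golds, and the scores are averaged.

gold_list = ["a cat", "his cat"]
pred = "a cat"
# i=0: scored against ["his cat"]: exact match 0, token F1 2/3
#      (squad_metrics normalizes answers, dropping the article "a")
# i=1: scored against ["a cat"]:   exact match 1, token F1 1
CoQA.compute_scores(gold_list, pred)  # -> {'em': 0.5, 'f1': 0.8333...}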
    def doc_to_target(self, doc, turnid=None):
        # Default to the prediction for the last turn.
        if turnid is None:
            turnid = len(doc["questions"])
        raw_text = doc['answers'][turnid - 1]["input_text"]
        return raw_text
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
@@ -58,9 +121,9 @@ class CoQA(Task):
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
        """
        cont_request = rf.greedy_until(ctx, ['\n'])
        return cont_request
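
The single stop sequence `'\n'` matches the prompt format built in `doc_to_text`, where every answer occupies one line. A hedged sketch of the round trip, with a hypothetical model continuation (the evaluation harness normally drives this):

task = CoQA()                    # note: instantiation triggers the dataset download
doc = task.validation_docs()[0]
ctx = task.doc_to_text(doc)      # zero-shot context for brevity
req = task.construct_requests(doc, ctx)    # greedy generation, stopping at '\n'
# ...suppose the LM's truncated continuation is " unknown":
task.process_results(doc, [" unknown"])    # -> {'f1': ..., 'em': ...}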
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluate, returning a
        dict where keys are the names of submetrics and values are the values of
@@ -71,23 +134,25 @@ class CoQA(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
        turn_id = len(doc["questions"])
        gold_list = self.get_answers(doc, turn_id)
        pred = results[0]

        scores = self.compute_scores(gold_list, pred)

        return {
            "f1": scores['f1'],
            "em": scores['em'],
        }
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "f1": True,
            "em": True,
        }
    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "f1": mean,
            "em": mean,
        }
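
As a sketch of the aggregation step, assuming `mean` is the arithmetic mean exported by `lm_eval.base`: the per-document dicts emitted by `process_results` are collected into per-metric lists, and each list is reduced to a single number.

from lm_eval.base import mean

# Per-document results as process_results would emit them (made-up values).
per_doc = [
    {"f1": 1.0, "em": 1.0},
    {"f1": 0.5, "em": 0.0},
]
f1 = mean([r["f1"] for r in per_doc])  # 0.75
em = mean([r["em"] for r in per_doc])  # 0.5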