Unverified Commit c155698f authored by Leo Gao, committed by GitHub

Merge pull request #133 from EleutherAI/fazz/refactor-task-coqa

CoQA Task Initial Implementation
parents 1e5194a2 431e53de
......@@ -60,7 +60,7 @@ class GPT2LM(LM):
for context, until in tqdm(requests):
if isinstance(until, str): until = [until]
context_enc = torch.tensor([self.tokenizer.encode(context)]).to(self.device)
context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - 1024:]]).to(self.device)
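# The slice above keeps only the most recent context tokens so that the prompt
# plus up to MAX_GEN_TOKS generated tokens still fits GPT-2's 1024-token window
# (e.g. assuming MAX_GEN_TOKS = 256, only the last 768 context tokens are kept).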
primary_until, = self.tokenizer.encode(until[0])
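# The single-element unpacking above assumes the stop string encodes to exactly one token.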
......
......@@ -5,6 +5,7 @@ import sacrebleu
from . import superglue
from . import glue
from . import arc
from . import coqa
from . import race
from . import webqs
from . import anli
......@@ -81,7 +82,7 @@ TASK_REGISTRY = {
"wsc": superglue.SGWinogradSchemaChallenge,
# Order by benchmark/genre?
"coqa": coqa.CoQA,
"lambada": lambada.LAMBADA,
"piqa": piqa.PiQA,
......
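For reference, registering "coqa" in TASK_REGISTRY is all that is needed to expose the task by name. A minimal lookup sketch, illustrative only and assuming this registry lives in lm_eval/tasks/__init__.py as the imports above suggest:
from lm_eval.tasks import TASK_REGISTRY
task_class = TASK_REGISTRY["coqa"]   # resolves to coqa.CoQA
coqa_task = task_class()             # __init__ triggers the dataset download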
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import os
import json
import random
from lm_eval.base import Task
from lm_eval.base import Task, rf, mean
from ..utils import sh
from itertools import zip_longest
import transformers.data.metrics.squad_metrics as squad_metrics
import collections
import datasets
import numpy as np
from lm_eval.base import rf, mean
from .common import HFTask
from tqdm import tqdm
import string, re
class CoQA(Task):
def __init__(self):
self.download()
def download(self):
#TODO: don't download if files already there
sh("""
mkdir -p data/coqa
wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json
wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json
""")
coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json'
coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json'
sh("""mkdir -p data/coqa""")
if not os.path.exists(coqa_train_filepath):
sh("""wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O """ + coqa_train_filepath)
if not os.path.exists(coqa_dev_filepath):
sh("""wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O """ + coqa_dev_filepath)
def has_training_docs(self):
return True
......@@ -30,22 +37,77 @@ class CoQA(Task):
return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
def validation_docs(self):
return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']
def test_docs(self):
pass
def fewshot_description(self):
# TODO: figure out description
return ""
return "Given a passage and a conversation so far, answer the next question in the conversation."
def doc_to_text(self, doc):
# TODO: implement.
raise NotImplementedError('doc_to_text not implemented')
# Given a passage p, the conversation history {q_1, a_1, ..., q_{i-1}, a_{i-1}},
# and a question q_i, the task is to predict the answer a_i.
doc_text = doc["story"] + '\n\n'
for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]): # omit target answer ai
question = f"Q: {q['input_text']}" + '\n\n'
answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A:"
doc_text += question + answer
return doc_text
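# Illustrative shape of the resulting prompt (hypothetical passage and turns):
#
#   <story text>
#
#   Q: What color was the cat?
#
#   A: white
#
#   Q: Where did she live?
#
#   A:
#
# The trailing "A:" is left open for the model to complete the final answer.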
@classmethod
def get_answers(cls, doc, turn_id):
# Return the gold answer for this turn plus any distinct alternative answers (some CoQA questions have multiple valid answers).
answers = []
answer_forturn = doc["answers"][turn_id - 1]["input_text"]
answers.append(answer_forturn)
additional_answers = doc.get("additional_answers")
if additional_answers:
for key in additional_answers:
additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
if additional_answer_for_turn.lower() not in map(str.lower, answers):
answers.append(additional_answer_for_turn)
return answers
@classmethod
def get_answer_choice(cls, raw_text):
# Function maps answers to CoQA answer categories
# ~ 1/5 of the CoQA answers are Yes/No
# ~ 2/3 of the CoQA answers are span-based
# (answers overlap with the passage ignoring punctuation and case mismatch)
if raw_text == "unknown":
return '0'
if squad_metrics.normalize_answer(raw_text) == "yes":
return '1'
if squad_metrics.normalize_answer(raw_text) == "no":
return '2'
return '3' # span-based or free-form answer (not unknown/yes/no)
def doc_to_target(self, doc):
# TODO: implement.
raise NotImplementedError('doc_to_target not implemented')
@staticmethod
def compute_scores(gold_list, pred):
# compute_exact tests for an exact match on the normalised answer;
# compute_f1 tests for token overlap between prediction and gold.
f1_sum = 0.0
em_sum = 0.0
if len(gold_list) > 1:
for i in range(len(gold_list)):
gold_answers = gold_list[0:i] + gold_list[i + 1:]
# the prediction is compared against the remaining (n-1) golds and the maximum score is taken
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
else:
em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)
return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
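# Hypothetical example: gold_list = ["white", "white in color"], pred = "white".
# Each gold is held out in turn and the prediction is scored against the rest,
# giving exact-match scores 0 and 1 and F1 scores 0.5 and 1.0, so em = 0.5, f1 = 0.75.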
def doc_to_target(self, doc, turnid=None):
# Default to prediction of last turn.
if turnid is None:
turnid = len(doc["questions"])
raw_text = doc['answers'][turnid - 1]["input_text"]
return " " + raw_text
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......@@ -58,9 +120,9 @@ class CoQA(Task):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
cont_request = rf.greedy_until(ctx, ['\n'])
return cont_request
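# rf.greedy_until defers generation to the evaluator/LM: the model decodes
# greedily from ctx until a newline, and the resulting string is handed back
# to process_results as results[0].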
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
......@@ -71,23 +133,25 @@ class CoQA(Task):
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
turn_id = len(doc["questions"])
gold_list = self.get_answers(doc, turn_id)
pred = results[0]
def aggregation(self):
"""
:returns: {str: [float] -> float}
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
scores = self.compute_scores(gold_list, pred)
return {
"f1": scores['f1'],
"em": scores['em'],
}
def higher_is_better(self):
"""
:returns: {str: bool}
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented')
return {
"f1": True,
"em": True,
}
def aggregation(self):
return {
"f1": mean,
"em": mean,
}
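Taken together, the task can now be driven end to end. A rough sketch of the flow, illustrative only (the real driver is evaluator.py, which batches requests; lm here stands for any LM object exposing greedy_until as in the GPT2LM change above, and the module path is assumed from the import added in lm_eval/tasks/__init__.py):
from lm_eval.tasks.coqa import CoQA
task = CoQA()                          # downloads the CoQA json files if missing
doc = task.validation_docs()[0]
ctx = task.doc_to_text(doc)            # passage + Q/A history, ending in "A:"
# construct_requests(doc, ctx) would hand the evaluator one rf.greedy_until request;
# here we call the LM directly instead, mimicking what the evaluator does with it:
completion = lm.greedy_until([(ctx, ['\n'])])[0]
print(task.process_results(doc, [completion]))   # -> {'f1': ..., 'em': ...}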
......@@ -75,6 +75,9 @@ def test_documents_and_requests(taskname, Task):
assert tgt[0] == ' ' or txt[-1] == '\n'
reqs = task.construct_requests(doc, txt)
# construct_requests can return just one request
if not isinstance(reqs, (list, tuple)): reqs = [reqs]
# todo: mock lm after refactoring evaluator.py to not be a mess
for req in reqs:
......