Commit b9b3159b authored by thefazzer's avatar thefazzer

Bugfixes, answer mapping, comments

parent 602d3e20
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import json
import random
import collections
import string, re
import numpy as np
import datasets
import transformers.data.metrics.squad_metrics as squad_metrics
from itertools import zip_longest
from tqdm import tqdm
from lm_eval.base import Task, rf, mean
from ..utils import sh
from .common import HFTask

class CoQA(Task):
    def download(self):
        pass
        # -N only overwrites if the remote file has changed
@@ -28,7 +34,11 @@ class CoQA(Task):
        return False

    def training_docs(self):
        doc_data = json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
        # Map each raw answer to its CoQA answer category in place, so that
        # few-shot examples present the same labels the model is scored on.
        for doc in doc_data:
            for answer in doc['answers']:
                answer['input_text'] = self.get_answer_choice(answer['input_text'])
        return doc_data
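    # Illustrative effect of the mapping on hypothetical answer dicts:
    #   {'input_text': 'Yes.', 'turn_id': 1}    ->  {'input_text': '1', 'turn_id': 1}
    #   {'input_text': 'a hotel', 'turn_id': 2} ->  {'input_text': '3', 'turn_id': 2}
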
    def validation_docs(self):
        return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']
@@ -40,9 +50,10 @@ class CoQA(Task):
        return "Given a passage and a conversation so far, answer the next question in the conversation."

    def doc_to_text(self, doc):
        # Given a passage p, the conversation history {q1, a1, ..., q_{i-1}, a_{i-1}}
        # and a question q_i, the task is to predict the answer a_i.
        doc_text = doc["story"] + '\n\n'
        for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer a_i
            question = f"Q: {q['input_text']}" + '\n\n'
            answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A: "
            doc_text += question + answer
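        # Rendered prompt shape for a two-turn conversation (hypothetical text):
        #   <story>\n\nQ: <q1>\n\nA: <a1>\n\nQ: <q2>\n\nA:
        # i.e. every earlier turn is completed and the final "A: " is left open.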
@@ -51,33 +62,43 @@ class CoQA(Task):
    @classmethod
    def get_answers(cls, doc, turn_id):
        # Returns the answer and any valid alternatives (some questions in CoQA
        # have multiple valid answers), deduplicated case-insensitively.
        answers = []
        answer_forturn = doc["answers"][turn_id - 1]["input_text"]
        answers.append(answer_forturn)

        additional_answers = doc.get("additional_answers")
        if additional_answers:
            for key in additional_answers:
                additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
                if additional_answer_for_turn.lower() not in map(str.lower, answers):
                    answers.append(additional_answer_for_turn)
        return answers
    @classmethod
    def get_answer_choice(cls, raw_text):
        # Maps an answer onto one of the CoQA answer categories:
        # ~1/5 of CoQA answers are Yes/No,
        # ~2/3 are span-based (they overlap with the passage,
        # ignoring punctuation and case mismatches).
        if raw_text == "unknown":
            return '0'
        if squad_metrics.normalize_answer(raw_text) == "yes":
            return '1'
        if squad_metrics.normalize_answer(raw_text) == "no":
            return '2'
        return '3'  # span-based answer (not yes/no/unknown)
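        # Sketch of the mapping on hypothetical inputs (normalize_answer
        # lower-cases and strips punctuation and articles):
        #   "unknown"        -> '0'
        #   "Yes." / "yes"   -> '1'
        #   "No!"            -> '2'
        #   "in the kitchen" -> '3'
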
    @staticmethod
    def compute_scores(gold_list, pred):
        # tests for exact match on the normalized answer (compute_exact)
        # and for token overlap (compute_f1)
        f1_sum = 0.0
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                gold_answers = gold_list[0:i] + gold_list[i + 1:]
                # the prediction is compared against the other (n - 1) golds
                # and the maximum score is taken
                em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
        else:
@@ -86,6 +107,14 @@ class CoQA(Task):
        return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
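    # Worked example with hypothetical labels: gold_list = ['1', '1', '3'], pred = '3'
    # leave-one-out comparisons: max vs ['1','3'] = 1, max vs ['1','3'] = 1, max vs ['1','1'] = 0
    # so em = f1 = 2/3: a single dissenting annotator does not zero the score.
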
    def doc_to_target(self, doc, turnid=None):
        # Default to predicting the last turn.
        if turnid is None:
            turnid = len(doc["questions"])
        raw_text = doc['answers'][turnid - 1]["input_text"]
        return self.get_answer_choice(raw_text)

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.
@@ -97,11 +126,11 @@ class CoQA(Task):
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
        """
        ll_requests = [
            rf.loglikelihood(ctx, " " + i)
            for i in ['0', '1', '2', '3']
        ]
        return ll_requests
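    # The four requests line up index-for-index with the labels '0'-'3', so
    # the np.argmax over their loglikelihoods in process_results below
    # recovers the predicted category directly.
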
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
@@ -113,16 +142,15 @@ class CoQA(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
        turn_id = len(doc["questions"])
        gold_list = [self.get_answer_choice(r_text) for r_text in self.get_answers(doc, turn_id)]
        pred = str(np.argmax(results))
        scores = self.compute_scores(gold_list, pred)
        return {
            "f1": scores['f1'],
            "em": scores['em'],
        }

    def higher_is_better(self):
...