Commit b9b3159b authored by thefazzer

Bugfixes, answer mapping, comments

parent 602d3e20
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import json
import random
import numpy as np
from lm_eval.base import Task, rf, mean
from ..utils import sh
from itertools import zip_longest
import transformers.data.metrics.squad_metrics as squad_metrics
import collections
import datasets
import numpy as np
from lm_eval.base import rf, mean
from . common import HFTask
from tqdm import tqdm
import string, re
class CoQA(Task):
    def download(self):
        pass
        # -N only overwrites if the remote file has changed
@@ -28,10 +34,14 @@ class CoQA(Task):
        return False
    def training_docs(self):
-        return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
+        doc_data = json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
+        for doc in doc_data:
+            for answer in doc['answers']:
+                answer['input_text'] = self.get_answer_choice(answer['input_text'])
+        return doc_data
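    # The gold answers in the training docs are mapped to the answer-category labels
    # produced by get_answer_choice below ('0'-'3'), presumably so that few-shot
    # examples and targets use the same label space that construct_requests scores.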
    def validation_docs(self):
        return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']

    def test_docs(self):
        pass
@@ -40,44 +50,55 @@ class CoQA(Task):
        return "Given a passage and a conversation so far, answer the next question in the conversation."
    def doc_to_text(self, doc):
        # Each "doc" is a story and conversation (Q and A pairs).
        # Given a passage p, the conversation history {q1, a1, ..., q_{i-1}, a_{i-1}}
        # and a question q_i, the task is to predict the answer a_i.
        doc_text = doc["story"] + '\n\n'
-        for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer
+        for (q, a) in zip_longest(doc["questions"], doc["answers"][:-1]):  # omit target answer a_i
            question = f"Q: {q['input_text']}" + '\n\n'
            answer = f"A: {a['input_text']}" + '\n\n' if a is not None else "A: "
            doc_text += question + answer
        print(doc_text)
        return doc_text
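    # Illustrative example (hypothetical mini-document, not taken from CoQA):
    # with doc = {"story": "Anna has a cat.",
    #             "questions": [{"input_text": "Who has a cat?"},
    #                           {"input_text": "What animal is it?"}],
    #             "answers":   [{"input_text": "Anna"},
    #                           {"input_text": "a cat"}]}
    # doc_to_text(doc) would produce roughly:
    #
    #     Anna has a cat.
    #
    #     Q: Who has a cat?
    #
    #     A: Anna
    #
    #     Q: What animal is it?
    #
    #     A: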
    @classmethod
    def get_answers(cls, doc, turn_id):
-        # This function returns an answer and valid alternatives.
+        # Returns unique answers and valid alternatives (some questions in CoQA have multiple valid answers).
        answers = []
        answer_forturn = doc["answers"][turn_id - 1]["input_text"]
        answers.append(answer_forturn)
-        additionals = doc.get("additional_answers")
-        if additionals:
-            for key in additionals:
-                additional_answer_for_turn = additionals[key][turn_id - 1]["input_text"]
-                if additional_answer_for_turn.upper() not in map(str.upper, answers):
+        additional_answers = doc.get("additional_answers")
+        if additional_answers:
+            for key in additional_answers:
+                additional_answer_for_turn = additional_answers[key][turn_id - 1]["input_text"]
+                if additional_answer_for_turn.lower() not in map(str.lower, answers):
                    answers.append(additional_answer_for_turn)
        return answers
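    # Example shape (assumed from the CoQA dev format): doc["additional_answers"] is a
    # dict keyed by annotator id ("0", "1", ...), each value a per-turn list of answer
    # dicts, so get_answers(doc, turn_id) might return e.g. ["Anna", "It was Anna"].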
-    def doc_to_target(self, doc, turnid=None):
-        # Default to predict last turn.
-        if turnid is None:
-            turnid = len(doc["questions"])
-        all_answers = self.get_answers(doc, turnid)
-        return all_answers[0]  # ignore alternative answers for now
+    @classmethod
+    def get_answer_choice(self, raw_text):
+        # Maps answers to CoQA answer categories:
+        # ~ 1/5 of the CoQA answers are Yes/No,
+        # ~ 2/3 of the CoQA answers are span-based
+        # (answers overlap with the passage ignoring punctuation and case mismatch).
+        if raw_text == "unknown":
+            return '0'
+        if squad_metrics.normalize_answer(raw_text) == "yes":
+            return '1'
+        if squad_metrics.normalize_answer(raw_text) == "no":
+            return '2'
+        return '3'  # Not a yes/no question
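    # Illustrative mapping (assuming squad_metrics.normalize_answer lower-cases and
    # strips punctuation): "unknown" -> '0', "Yes." -> '1', "No" -> '2',
    # and a span answer such as "Jessica" -> '3'.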
    @staticmethod
    def compute_scores(gold_list, pred):
        # Tests for an exact match on the normalised answer (compute_exact)
        # and for token overlap (compute_f1).
        f1_sum = 0.0
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                gold_answers = gold_list[0:i] + gold_list[i + 1:]
                # the prediction is compared against all other golds; take the maximum
                em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_answers)
                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
        else:
@@ -86,6 +107,14 @@ class CoQA(Task):
        return {'em': em_sum / max(1, len(gold_list)), 'f1': f1_sum / max(1, len(gold_list))}
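    # Worked example (illustrative): gold_list = ['1', '3'], pred = '1'
    #   i=0: compare against ['3'] -> exact 0, f1 0
    #   i=1: compare against ['1'] -> exact 1, f1 1
    #   em = 1/2 = 0.5, f1 = 1/2 = 0.5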
+    def doc_to_target(self, doc, turnid=None):
+        # Default to predicting the last turn.
+        if turnid is None:
+            turnid = len(doc["questions"])
+        raw_text = doc['answers'][turnid - 1]["input_text"]
+        return self.get_answer_choice(raw_text)
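    # Unlike the removed variant above, the target is now the answer-category label
    # rather than the raw answer text, matching the label space scored in
    # construct_requests below.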
    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.
@@ -97,12 +126,12 @@ class CoQA(Task):
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
        """
-        requests = []
-        for answers in self.get_answers(doc, len(doc["questions"])):
-            for a in answers:
-                requests.append(rf.loglikelihood(ctx, " " + a))
-        return requests
+        ll_requests = [
+            rf.loglikelihood(ctx, " " + i)
+            for i in ['0', '1', '2', '3']
+        ]
+        return ll_requests
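    # The four requests score the answer-category labels '0'-'3' as continuations of
    # ctx; process_results below takes the argmax over the returned values, so the
    # index of the best-scoring label becomes the predicted category.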
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
@@ -113,16 +142,15 @@ class CoQA(Task):
        :param results:
            The results of the requests created in construct_requests.
        """
        turn_id = len(doc["questions"])
-        gold_list = self.get_answers(doc, turn_id)
-        pred = np.argmax(results)
+        gold_list = [self.get_answer_choice(r_text) for r_text in self.get_answers(doc, turn_id)]
+        pred = str(np.argmax(results))
-        (em, f1) = self.compute_scores(gold_list, pred)
+        scores = self.compute_scores(gold_list, pred)
        return {
-            "f1": f1,
-            "em": em,
+            "f1": scores['f1'],
+            "em": scores['em'],
        }
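    # Illustrative flow (hypothetical numbers): results = [-4.2, -0.3, -5.0, -2.7] for
    # the labels '0'-'3' gives pred = '1'; with gold categories ['1', '3'] the
    # leave-one-out comparison in compute_scores yields {'em': 0.5, 'f1': 0.5}.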
    def higher_is_better(self):
......