Commit 68122ef4 authored by thefazzer's avatar thefazzer
Browse files

Call greedy_until continuation, download fix

parent 37c3139d
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted.
import os
import json
from lm_eval.base import Task, rf, mean
from ..utils import sh
......@@ -16,13 +15,14 @@ import string, re
class CoQA(Task):
def download(self):
    """Fetch the CoQA train and dev JSON files into ``data/coqa``.

    A file is downloaded only when it is not already present on disk, so
    repeated calls do not hit the network again.
    """
    base_url = 'http://downloads.cs.stanford.edu/nlp/data/coqa/'
    data_dir = 'data/coqa'
    sh("""mkdir -p """ + data_dir)
    # An explicit existence check is used instead of `wget -N`: wget's
    # timestamping has no effect when an output path is forced with -O,
    # so `-N ... -O ...` would re-download the file on every call.
    for filename in ('coqa-train-v1.0.json', 'coqa-dev-v1.0.json'):
        filepath = data_dir + '/' + filename
        if not os.path.exists(filepath):
            # NOTE(review): wget -O creates the target even when the
            # transfer fails; presumably `sh` raises on a non-zero exit,
            # but a partial file would then be skipped on the next run.
            # Confirm and clean up on failure if needed.
            sh("""wget """ + base_url + filename + """ -O """ + filepath)
def has_training_docs(self):
    """Report that this task ships a training split."""
    return True
......@@ -34,11 +34,7 @@ class CoQA(Task):
return False
def training_docs(self):
    """Return the ``data`` list from the CoQA training JSON file.

    The documents are returned exactly as stored on disk. Answer texts are
    deliberately left untouched so that they can be compared against
    free-form generated continuations.
    """
    # Context manager closes the file handle deterministically.
    with open('data/coqa/coqa-train-v1.0.json', encoding='utf-8') as f:
        return json.load(f)['data']
def validation_docs(self):
    """Return the ``data`` list from the CoQA dev JSON file, unmodified."""
    # Context manager closes the file handle deterministically.
    with open('data/coqa/coqa-dev-v1.0.json', encoding='utf-8') as f:
        return json.load(f)['data']
......@@ -111,9 +107,8 @@ class CoQA(Task):
# Default to prediction of last turn.
if turnid is None:
turnid = len(doc["questions"])
raw_text = doc['answers'][turnid - 1]["input_text"]
return self.get_answer_choice(raw_text)
return raw_text
def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of
......@@ -126,11 +121,8 @@ class CoQA(Task):
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
ll_requests = [
rf.loglikelihood(ctx, " " + i)
for i in ['0', '1', '2', '3']
]
return ll_requests
cont_request = rf.greedy_until(ctx, ['\n'])
return cont_request
def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a
......@@ -143,8 +135,8 @@ class CoQA(Task):
The results of the requests created in construct_requests.
"""
turn_id = len(doc["questions"])
gold_list = [self.get_answer_choice(r_text) for r_text in self.get_answers(doc, turn_id)]
pred = str(np.argmax(results))
gold_list = self.get_answers(doc, turn_id)
pred = results[0]
scores = self.compute_scores(gold_list, pred)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment