Commit 68122ef4 authored by thefazzer
Browse files

Call greedy_until continuation, download fix

parent 37c3139d
# REMINDER: this code needs to be rewritten for the new framework. Remove this comment when the code is fully converted. import os
import json import json
from lm_eval.base import Task, rf, mean from lm_eval.base import Task, rf, mean
from ..utils import sh from ..utils import sh
...@@ -16,13 +15,14 @@ import string, re ...@@ -16,13 +15,14 @@ import string, re
class CoQA(Task): class CoQA(Task):
def download(self): def download(self):
pass coqa_train_filepath = 'data/coqa/coqa-train-v1.0.json'
# -N only overwrites if the remote file has changed coqa_dev_filepath = 'data/coqa/coqa-dev-v1.0.json'
sh ("""
mkdir -p data/coqa sh ("""mkdir -p data/coqa""")
wget -N http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json if not os.path.exists(coqa_train_filepath):
wget -N http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json sh ("""wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O """ + coqa_train_filepath)
""") if not os.path.exists(coqa_dev_filepath):
sh ("""wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O """ + coqa_dev_filepath)
def has_training_docs(self): def has_training_docs(self):
return True return True
...@@ -34,11 +34,7 @@ class CoQA(Task): ...@@ -34,11 +34,7 @@ class CoQA(Task):
return False return False
def training_docs(self): def training_docs(self):
doc_data = json.load(open('data/coqa/coqa-train-v1.0.json'))['data'] return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
for doc in doc_data:
for answer in doc['answers']:
answer['input_text'] = self.get_answer_choice(answer['input_text'])
return doc_data
def validation_docs(self): def validation_docs(self):
return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data'] return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']
...@@ -111,9 +107,8 @@ class CoQA(Task): ...@@ -111,9 +107,8 @@ class CoQA(Task):
# Default to prediction of last turn. # Default to prediction of last turn.
if turnid is None: if turnid is None:
turnid = len(doc["questions"]) turnid = len(doc["questions"])
raw_text = doc['answers'][turnid - 1]["input_text"] raw_text = doc['answers'][turnid - 1]["input_text"]
return self.get_answer_choice(raw_text) return raw_text
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
...@@ -126,11 +121,8 @@ class CoQA(Task): ...@@ -126,11 +121,8 @@ class CoQA(Task):
language description, as well as the few shot examples, and the question language description, as well as the few shot examples, and the question
part of the document for `doc`. part of the document for `doc`.
""" """
ll_requests = [ cont_request = rf.greedy_until(ctx, ['\n'])
rf.loglikelihood(ctx, " " + i) return cont_request
for i in ['0', '1', '2', '3']
]
return ll_requests
def process_results(self, doc, results): def process_results(self, doc, results):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
...@@ -143,8 +135,8 @@ class CoQA(Task): ...@@ -143,8 +135,8 @@ class CoQA(Task):
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
turn_id = len(doc["questions"]) turn_id = len(doc["questions"])
gold_list = [self.get_answer_choice(r_text) for r_text in self.get_answers(doc, turn_id)] gold_list = self.get_answers(doc, turn_id)
pred = str(np.argmax(results)) pred = results[0]
scores = self.compute_scores(gold_list, pred) scores = self.compute_scores(gold_list, pred)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment