Unverified commit b278e42d, authored by Stella Biderman and committed by GitHub
Browse files

Merge pull request #53 from anishthite/master

Update drop to be consistent with gpt3 paper
parents 00abc99a 042e2926
...@@ -5,8 +5,10 @@ from ..utils import sh ...@@ -5,8 +5,10 @@ from ..utils import sh
class CoQA(Dataset): class CoQA(Dataset):
def __init__(self):
self.download()
def download(self): def download(self):
#TODO: don't download if files already there
sh(""" sh("""
mkdir -p data/coqa mkdir -p data/coqa
wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json
...@@ -17,48 +19,31 @@ class CoQA(Dataset): ...@@ -17,48 +19,31 @@ class CoQA(Dataset):
return True return True
def has_validation_docs(self): def has_validation_docs(self):
return False return True
def has_test_docs(self): def has_test_docs(self):
return False return False
def training_docs(self): def training_docs(self):
myjson = json.load(open('data/coqa/coqa-train-v1.0.json'))['data'] return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
return self.load_doc(myjson)
def validation_docs(self): def validation_docs(self):
pass return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']
def test_docs(self): def test_docs(self):
myjson = json.load(open('data/coqa/coqa-dev-v1.0.json'))['data'] pass
return self.load_doc(myjson)
def fewshot_examples(self, k):
traindocs = list(self.training_docs())
random.seed(123)
random.shuffle(traindocs)
return traindocs[:k]
def fewshot_description(self): def fewshot_description(self):
pass pass
def load_doc(self, myjson):
docs = []
for item in myjson:
new_instance = [item['story']]
qa_pairs = zip(item['questions'], item['answers'])
for pair in qa_pairs:
new_instance.append('\n')
new_instance.append(''.join(['Q: ',pair[0]['input_text']]))
new_instance.append(''.join(['A: ',pair[1]['input_text']]))
docs.append(new_instance)
return docs
def doc_to_text(self, doc, include_target=True): def doc_to_text(self, doc, include_target=True):
text = '\n<|endoftext|>\n'.join(['\n'.join(instance) for instance in doc]) text = [doc['story']]
text = text + '\n<|endoftext|>' for pair in zip(doc['questions'], doc['answers']):
return text text.append('\n\n')
text.append(''.join(['Q: ',pair[0]['input_text'], '\n\n']))
text.append(''.join(['A: ',pair[1]['input_text']]))
return ''.join(text)
def evaluate(self, docs, lm): def evaluate(self, docs, lm):
pass pass
...@@ -10,6 +10,9 @@ from ..base import Dataset ...@@ -10,6 +10,9 @@ from ..base import Dataset
class DROP(Dataset): class DROP(Dataset):
DATAFOLDER = Path(__file__).parent / "../../data/drop" DATAFOLDER = Path(__file__).parent / "../../data/drop"
def __init__(self):
self.download()
def has_training_docs(self): def has_training_docs(self):
"""Whether the task has a training set""" """Whether the task has a training set"""
return True return True
...@@ -35,10 +38,10 @@ class DROP(Dataset): ...@@ -35,10 +38,10 @@ class DROP(Dataset):
pass pass
def doc_to_text(self, doc, include_target=True): def doc_to_text(self, doc, include_target=True):
doctext = "Passage: {}\n\n".format(doc["passage"]) doctext = "Passage: {}\n".format(doc["passage"])
qa_texts = [] qa_texts = []
for pair in doc["qa_pairs"]: for pair in doc["qa_pairs"]:
text = ''.join(['Q: ', pair['question'],'\nA: ']) text = ''.join(['Question: ', pair['question'],'\nAnswer: '])
if include_target: if include_target:
def get_answer(ans_dict): def get_answer(ans_dict):
if ans_dict['number'] != '': if ans_dict['number'] != '':
...@@ -52,7 +55,7 @@ class DROP(Dataset): ...@@ -52,7 +55,7 @@ class DROP(Dataset):
ans_dict['date']['year']]).strip() ans_dict['date']['year']]).strip()
text = ''.join([text, get_answer(pair['answer'])]) text = ''.join([text, get_answer(pair['answer'])])
qa_texts.append(text) qa_texts.append(text)
return ''.join([doctext, '\n\n'.join(qa_texts)]) return ''.join([doctext, '\n'.join(qa_texts)])
def evaluate(self, docs, lm, provide_description, num_fewshot): def evaluate(self, docs, lm, provide_description, num_fewshot):
......
import json
import random
from lm_eval.base import Dataset
from ..utils import sh
import csv
class StoryCloze(Dataset):
    """Story Cloze task backed by local CSV files under ``data/storycloze``.

    Every document is one CSV row, kept as the list of strings that
    ``csv.reader`` yields for it.
    """

    def __init__(self):
        self.download()

    def download(self):
        """Placeholder: nothing is fetched yet, the CSV files must already exist."""
        #TODO: replace with Eye link
        pass

    def has_training_docs(self):
        """Report that no training split is provided."""
        return False

    def has_validation_docs(self):
        """Report that a validation split is provided."""
        return True

    def has_test_docs(self):
        """Report that a test split is provided."""
        return True

    def training_docs(self):
        """Return ``None``; there is no training split."""
        pass

    def load_doc(self, filename):
        """Read ``filename`` as CSV and return all of its rows.

        Args:
            filename: path of the CSV file to read.

        Returns:
            A list with one list of strings per CSV row. No row is filtered
            out, so a header row (if the file has one) is returned as well.
        """
        # newline='' lets the csv module handle line endings inside quoted fields.
        with open(filename, newline='') as file:
            return list(csv.reader(file))

    def validation_docs(self):
        """Return the rows of the validation CSV."""
        return self.load_doc("data/storycloze/cloze_test_val__winter2018-cloze_test_ALL_val - 1 - 1.csv")

    def test_docs(self):
        """Return the rows of the test CSV."""
        return self.load_doc("data/storycloze/cloze_test_test__winter2018-cloze_test_ALL_test - 1.csv")

    def fewshot_description(self):
        """Return ``None``; no task description is defined yet."""
        pass

    def doc_to_text(self, doc, include_target=True):
        """Render one row as a single space-separated string.

        Args:
            doc: one CSV row (sequence of strings). Fields 1-4 form the story.
                The last field is parsed as an integer ``n`` and the ending is
                taken from ``doc[n - 4]``.
                NOTE(review): this presumably selects one of the two candidate
                endings stored just before the last column -- confirm against
                the CSV layout.
            include_target: when true (the default) the selected ending is
                appended after the story; when false only the four story
                sentences are returned, matching how ``DROP.doc_to_text``
                treats this flag.

        Returns:
            The joined text.

        Raises:
            ValueError: if ``include_target`` is true and the last field is
                not an integer.
        """
        pieces = list(doc[1:5])
        if include_target:
            pieces.append(doc[int(doc[-1]) - 4])
        return ' '.join(pieces)

    def evaluate(self, docs, lm):
        """Return ``None``; evaluation is not implemented yet."""
        pass
import json
import random
from lm_eval.base import Dataset
from ..utils import sh
class TriviaQA(Dataset):
    """TriviaQA (unfiltered, web) task reading JSON splits from ``data/triviaqa``."""

    def __init__(self):
        self.download()

    def download(self):
        """Fetch and unpack the unfiltered TriviaQA archive through ``sh``.

        The script creates ``data/triviaqa``, downloads the tarball into it
        with ``wget``, extracts it into the current working directory and then
        moves the extracted ``triviaqa-unfiltered`` folder under
        ``data/triviaqa``.
        """
        #TODO: don't download if files already there
        # The script lines stay flush-left so the string handed to sh() is
        # exactly the original text.
        sh("""
mkdir -p data/triviaqa
wget http://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz -O data/triviaqa/trivia_qa-unfiltered.tar.gz
tar -xf data/triviaqa/trivia_qa-unfiltered.tar.gz
mv triviaqa-unfiltered/ data/triviaqa/
""")

    def has_training_docs(self):
        """Report that a training split is provided."""
        return True

    def has_validation_docs(self):
        """Report that a validation split is provided."""
        return True

    def has_test_docs(self):
        """Report that a test split is provided."""
        return True

    def _load_split(self, path):
        """Return the value under the ``'Data'`` key of the JSON file at ``path``."""
        # The context manager closes the handle even if parsing fails; the
        # encoding is pinned so the result does not depend on the locale.
        with open(path, encoding='utf-8') as handle:
            return json.load(handle)['Data']

    def training_docs(self):
        """Return the ``'Data'`` entries of the training JSON file."""
        return self._load_split('data/triviaqa/triviaqa-unfiltered/unfiltered-web-train.json')

    def validation_docs(self):
        """Return the ``'Data'`` entries of the dev JSON file."""
        return self._load_split('data/triviaqa/triviaqa-unfiltered/unfiltered-web-dev.json')

    def test_docs(self):
        """Return the ``'Data'`` entries of the test JSON file.

        NOTE(review): confirm that the extracted archive really contains a
        file with this exact name.
        """
        return self._load_split('data/triviaqa/triviaqa-unfiltered/unfiltered-web-test.json')

    def fewshot_description(self):
        """Return ``None``; no task description is defined yet."""
        pass

    def doc_to_text(self, doc, include_target=True):
        """Render one document as a ``Q: ... A: ...`` prompt.

        Args:
            doc: mapping with a ``'Question'`` string and, when the target is
                requested, an ``'Answer'`` mapping whose ``'Aliases'`` list has
                at least one entry.
            include_target: when true (the default) the first alias is
                appended after ``'A: '``; when false the prompt stops right
                after ``'A: '`` and ``doc['Answer']`` is not read, matching how
                ``DROP.doc_to_text`` treats this flag.

        Returns:
            The prompt string.
        """
        prompt = ''.join(['Q: ', doc['Question'], '\n\n', 'A: '])
        if include_target:
            prompt = ''.join([prompt, doc['Answer']['Aliases'][0]])
        return prompt

    def evaluate(self, docs, lm):
        """Return ``None``; evaluation is not implemented yet."""
        pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment