Commit 5d601e14 authored by Anish Thite
Browse files

update coqa to be consistent with gpt3 paper

parent 01608cf4
...@@ -5,8 +5,10 @@ from ..utils import sh ...@@ -5,8 +5,10 @@ from ..utils import sh
class CoQA(Dataset): class CoQA(Dataset):
def __init__(self):
self.download()
def download(self): def download(self):
#TODO: don't download if files already there
sh(""" sh("""
mkdir -p data/coqa mkdir -p data/coqa
wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json
...@@ -23,42 +25,25 @@ class CoQA(Dataset): ...@@ -23,42 +25,25 @@ class CoQA(Dataset):
return False return False
def training_docs(self): def training_docs(self):
myjson = json.load(open('data/coqa/coqa-train-v1.0.json'))['data'] return json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
return self.load_doc(myjson)
def validation_docs(self): def validation_docs(self):
pass pass
def test_docs(self): def test_docs(self):
myjson = json.load(open('data/coqa/coqa-dev-v1.0.json'))['data'] return json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']
return self.load_doc(myjson)
def fewshot_examples(self, k):
traindocs = list(self.training_docs())
random.seed(123)
random.shuffle(traindocs)
return traindocs[:k]
def fewshot_description(self): def fewshot_description(self):
pass pass
def load_doc(self, myjson):
docs = []
for item in myjson:
new_instance = [item['story']]
qa_pairs = zip(item['questions'], item['answers'])
for pair in qa_pairs:
new_instance.append('\n')
new_instance.append(''.join(['Q: ',pair[0]['input_text']]))
new_instance.append(''.join(['A: ',pair[1]['input_text']]))
docs.append(new_instance)
return docs
def doc_to_text(self, doc, include_target=True): def doc_to_text(self, doc, include_target=True):
text = '\n<|endoftext|>\n'.join(['\n'.join(instance) for instance in doc]) text = [doc['story']]
text = text + '\n<|endoftext|>' for pair in zip(doc['questions'], doc['answers']):
return text text.append('\n\n')
text.append(''.join(['Q: ',pair[0]['input_text'], '\n\n']))
text.append(''.join(['A: ',pair[1]['input_text']]))
return ''.join(text)
def evaluate(self, docs, lm): def evaluate(self, docs, lm):
pass pass
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment