Unverified commit 9eec7ce8, authored by Stella Biderman, committed by GitHub
Browse files

Merge pull request #55 from anishthite/master

Add piqa and implement include_target for triviqa, storycloze, coqa
parents 3c480b61 171f2924
...@@ -41,7 +41,10 @@ class CoQA(Dataset): ...@@ -41,7 +41,10 @@ class CoQA(Dataset):
for pair in zip(doc['questions'], doc['answers']): for pair in zip(doc['questions'], doc['answers']):
text.append('\n\n') text.append('\n\n')
text.append(''.join(['Q: ',pair[0]['input_text'], '\n\n'])) text.append(''.join(['Q: ',pair[0]['input_text'], '\n\n']))
text.append(''.join(['A: ',pair[1]['input_text']])) if include_target:
text.append(''.join(['A: ',pair[1]['input_text']]))
else:
text.append('A: ')
return ''.join(text) return ''.join(text)
......
import json
import random
from lm_eval.base import Dataset
from ..utils import sh
class PiQA(Dataset):
    """Dataset wrapper for the PIQA files hosted at yonatanbisk.com/piqa.

    Labelled splits (train, validation) are exposed as ``(example, label)``
    pairs, where ``example`` is the decoded JSON object of one line of the
    ``.jsonl`` file and ``label`` is the matching raw line of the ``.lst``
    file. The test split has no label file and is exposed as bare
    ``example`` dicts.
    """

    def __init__(self):
        """Fetch the data files as soon as the task object is created."""
        self.download()

    def download(self):
        """Create ``data/piqa`` and download every split into it via ``sh``."""
        # TODO: don't download if files already there
        sh("""
mkdir -p data/piqa
wget https://yonatanbisk.com/piqa/data/train.jsonl -O data/piqa/piqa-train.jsonl
wget https://yonatanbisk.com/piqa/data/train-labels.lst -O data/piqa/piqa-train-labels.lst
wget https://yonatanbisk.com/piqa/data/valid.jsonl -O data/piqa/piqa-valid.jsonl
wget https://yonatanbisk.com/piqa/data/valid-labels.lst -O data/piqa/piqa-valid-labels.lst
wget https://yonatanbisk.com/piqa/data/tests.jsonl -O data/piqa/piqa-test.jsonl
""")

    def has_training_docs(self):
        """Report that a training split is available."""
        return True

    def has_validation_docs(self):
        """Report that a validation split is available."""
        return True

    def has_test_docs(self):
        """Report that a test split is available."""
        return True

    def load_docs(self, textfilename, labelfilename):
        """Load one split from disk.

        Args:
            textfilename: Path to a JSON-lines file, one example per line.
            labelfilename: Path to a file with one label per line, or
                ``None`` when the split has no labels.

        Returns:
            A list of decoded examples when ``labelfilename`` is ``None``;
            otherwise a list of ``(example, label_line)`` tuples. Label
            lines are kept exactly as read (including the line terminator).
        """
        # Context managers make sure the handles are closed once read.
        with open(textfilename, 'r') as text_file:
            examples = [json.loads(entry) for entry in text_file]
        if labelfilename is None:
            return examples
        with open(labelfilename, 'r') as label_file:
            labels = list(label_file)
        # Materialise the pairs so the split can be iterated more than once.
        return list(zip(examples, labels))

    def training_docs(self):
        """Return the labelled training documents."""
        return self.load_docs('data/piqa/piqa-train.jsonl', 'data/piqa/piqa-train-labels.lst')

    def validation_docs(self):
        """Return the labelled validation documents."""
        return self.load_docs('data/piqa/piqa-valid.jsonl', 'data/piqa/piqa-valid-labels.lst')

    def test_docs(self):
        """Return the unlabelled test documents."""
        return self.load_docs('data/piqa/piqa-test.jsonl', None)

    def fewshot_description(self):
        """Not implemented yet; returns ``None``."""
        pass

    def doc_to_text(self, doc, include_target=True):
        """Render one document as text.

        Args:
            doc: Either an ``(example, label)`` pair, as produced for the
                labelled splits, or a bare example dict, as produced for
                the test split.
            include_target: When true, append the solution selected by the
                label after the goal; otherwise return only the goal
                followed by a single space.

        Returns:
            The rendered string.

        Raises:
            ValueError: If ``include_target`` is true but ``doc`` carries
                no label.
        """
        # The two document shapes come from load_docs; normalise them here
        # so both values of include_target work on every split.
        if isinstance(doc, dict):
            example, label = doc, None
        else:
            example, label = doc[0], doc[1]

        if not include_target:
            # TODO: check if oa uses newline
            return example['goal'] + ' '

        if label is None:
            raise ValueError(
                'include_target=True requires a labelled (example, label) document'
            )
        # Only the first character of the label line is used, then shifted
        # by one to form the 'sol<N>' key (presumably labels are 0-based
        # and the keys are 'sol1'/'sol2' -- confirm against the data files).
        right_answer = int(label[0]) + 1
        return ''.join([example['goal'], ' ', example['sol' + str(right_answer)]])

    def evaluate(self, docs, lm):
        """Not implemented yet; returns ``None``."""
        pass
...@@ -40,7 +40,10 @@ class StoryCloze(Dataset): ...@@ -40,7 +40,10 @@ class StoryCloze(Dataset):
pass pass
def doc_to_text(self, doc, include_target=True): def doc_to_text(self, doc, include_target=True):
return ' '.join([*doc[1:5],doc[int(doc[-1])-4]]) if include_target:
return ' '.join([*doc[1:5],doc[int(doc[-1])-4]])
else:
return ' '.join([*doc[1:5]])
def evaluate(self, docs, lm): def evaluate(self, docs, lm):
pass pass
......
...@@ -38,8 +38,10 @@ class TriviaQA(Dataset): ...@@ -38,8 +38,10 @@ class TriviaQA(Dataset):
pass pass
def doc_to_text(self, doc, include_target=True): def doc_to_text(self, doc, include_target=True):
return ''.join(['Q: ', doc['Question'], '\n\n','A: ', doc['Answer']['Aliases'][0]]) if include_target:
return ''.join(['Q: ', doc['Question'], '\n\n','A: ', doc['Answer']['Aliases'][0]])
else:
return ''.join(['Q: ', doc['Question'], '\n\n','A: '])
def evaluate(self, docs, lm): def evaluate(self, docs, lm):
pass pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment