# CoQA task: downloads the CoQA train/dev JSON files and renders each document as story + Q/A text.
import json
import random
from lm_eval.base import Dataset
from ..utils import sh

class CoQA(Dataset):
    """CoQA conversational question-answering dataset.

    Instantiating the class downloads the train and dev JSON files into
    ``data/coqa``.  Each document is a dict read from the ``'data'`` list of
    those files; ``doc_to_text`` reads its ``'story'``, ``'questions'`` and
    ``'answers'`` entries.
    """

    def __init__(self):
        self.download()

    def download(self):
        """Fetch the CoQA train and dev splits into ``data/coqa`` via the shell."""
        # TODO: don't download if files already there
        # NOTE: a plain "file exists" check is not sufficient here, because
        # `wget -O` leaves an (empty) output file behind when a download fails.
        sh("""
            mkdir -p data/coqa 
            wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json
            wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json
            """)

    @staticmethod
    def _load_docs(path):
        """Return the ``'data'`` list stored in the JSON file at ``path``.

        The file is opened in a ``with`` block so the handle is closed as soon
        as parsing finishes, and it is decoded explicitly as UTF-8 rather than
        with the platform's default encoding.
        """
        with open(path, encoding='utf-8') as handle:
            return json.load(handle)['data']

    def has_training_docs(self):
        """Return True: a training split is available."""
        return True

    def has_validation_docs(self):
        """Return True: a validation (dev) split is available."""
        return True

    def has_test_docs(self):
        """Return False: no test split is provided."""
        return False

    def training_docs(self):
        """Return the list of training documents."""
        return self._load_docs('data/coqa/coqa-train-v1.0.json')

    def validation_docs(self):
        """Return the list of validation (dev) documents."""
        return self._load_docs('data/coqa/coqa-dev-v1.0.json')

    def test_docs(self):
        """Return None; there is no test split (see ``has_test_docs``)."""
        return None

    def fewshot_description(self):
        """Return None; no few-shot description is defined yet."""
        return None

    def doc_to_text(self, doc, include_target=True):
        """Render a document as its story followed by every Q/A turn.

        Args:
            doc: dict with a ``'story'`` string and parallel ``'questions'``
                and ``'answers'`` lists whose items carry ``'input_text'``.
            include_target: when True each ``A: `` is followed by the answer
                text; when False every answer slot is left as a bare ``A: ``.

        Returns:
            The story, then for each turn a blank line, ``Q: <question>``,
            a blank line and the ``A: `` line, concatenated into one string.
        """
        pieces = [doc['story']]
        for question, answer in zip(doc['questions'], doc['answers']):
            pieces.append('\n\n')
            pieces.append('Q: ' + question['input_text'] + '\n\n')
            if include_target:
                pieces.append('A: ' + answer['input_text'])
            else:
                pieces.append('A: ')

        return ''.join(pieces)

    def evaluate(self, docs, lm):
        """Not implemented yet; currently returns None."""
        return None