coqa.py 1.82 KB
Newer Older
1
2
import json
import random
Jason Phang's avatar
lib  
Jason Phang committed
3
from lm_eval.base import Dataset
sdtblck's avatar
sdtblck committed
4
from ..utils import sh
Jason Phang's avatar
gpt3  
Jason Phang committed
5

6
7

class CoQA(Dataset):
sdtblck's avatar
sdtblck committed
8
9
10
11
12
13
14
15

    def download(self):
        """Fetch the CoQA v1.0 train and dev JSON files into ``data/coqa``.

        Hands a three-command shell script to the ``sh`` helper imported from
        ``..utils``: ``mkdir -p`` for the target directory, then one ``wget``
        per split, each writing to a fixed path under ``data/coqa``.

        NOTE(review): how ``sh`` reports a failing command is not visible
        from this file -- confirm before relying on errors being raised.
        """
        sh("""
            mkdir -p data/coqa 
            wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-train-v1.0.json -O data/coqa/coqa-train-v1.0.json
            wget http://downloads.cs.stanford.edu/nlp/data/coqa/coqa-dev-v1.0.json -O data/coqa/coqa-dev-v1.0.json
            """)

16
17
18
19
20
    def has_training_docs(self):
        """Report that this task provides training documents (always True)."""
        return True

    def has_validation_docs(self):
        """Report that this task provides no validation documents (always False)."""
        return False
Jason Phang's avatar
Jason Phang committed
21
22
23
24

    def has_test_docs(self):
        """Report that this task provides no test documents (always False).

        NOTE(review): ``test_docs`` in this class does load and return the
        dev-set file, so this flag and that method disagree -- confirm which
        one reflects the intended behaviour.
        """
        return False

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
    def training_docs(self):
        myjson = json.load(open('data/coqa/coqa-train-v1.0.json'))['data']
        return self.load_doc(myjson)

    def validation_docs(self):
        """Placeholder: no validation split is loaded, so this returns None."""
        return None

    def test_docs(self):
        myjson = json.load(open('data/coqa/coqa-dev-v1.0.json'))['data']    
        return self.load_doc(myjson)
    
    def fewshot_examples(self, k):
        traindocs = list(self.training_docs())
        random.seed(123)
        random.shuffle(traindocs)

        return traindocs[:k]
    
    def fewshot_description(self):
        """Placeholder: no few-shot description is defined, so this returns None."""
        return None

    def load_doc(self, myjson):
        """Turn CoQA story records into flat lists of prompt lines.

        Args:
            myjson: Iterable of records, each providing ``'story'``,
                ``'questions'`` and ``'answers'``; every question and answer
                is read through its ``'input_text'`` entry.

        Returns:
            One list per record: the story text followed, for each
            question/answer pair, by ``'\\n'``, ``'Q: <question>'`` and
            ``'A: <answer>'``. Pairs are formed with ``zip``, so surplus
            questions or answers are dropped.
        """
        def render(record):
            # The story opens the document; each turn adds three entries.
            lines = [record['story']]
            for question, answer in zip(record['questions'], record['answers']):
                lines.extend((
                    '\n',
                    ''.join(('Q: ', question['input_text'])),
                    ''.join(('A: ', answer['input_text'])),
                ))
            return lines

        return [render(record) for record in myjson]
    
    def doc_to_text(self, doc, include_target=True):
        """Render a sequence of line lists as one ``<|endoftext|>``-delimited string.

        Each element of ``doc`` is joined with newlines; the resulting chunks
        are separated by ``'\\n<|endoftext|>\\n'`` and the whole text ends
        with ``'\\n<|endoftext|>'``.

        Args:
            doc: Iterable whose elements are iterables of strings.
                NOTE(review): given a single ``load_doc`` entry (a flat list
                of strings) each string would be split into characters --
                confirm what callers pass here.
            include_target: Accepted but not used by this implementation.

        Returns:
            The assembled text.
        """
        separator = '\n<|endoftext|>\n'
        terminator = '\n<|endoftext|>'
        chunks = ('\n'.join(instance) for instance in doc)
        return separator.join(chunks) + terminator

    def evaluate(self, docs, lm):
        """Placeholder: evaluation is not implemented; both arguments are ignored and None is returned."""
        return None