piqa.py 2.44 KB
Newer Older
Anish Thite's avatar
Anish Thite committed
1
2
import json
import random
Leo Gao's avatar
Leo Gao committed
3
from lm_eval.base import Dataset, rf, mean
Anish Thite's avatar
Anish Thite committed
4
from ..utils import sh
Leo Gao's avatar
Leo Gao committed
5
import os
Anish Thite's avatar
Anish Thite committed
6
7
8

class PiQA(Dataset):
    """PIQA two-way multiple-choice task.

    A labelled document is an ``(entry, label)`` pair. ``entry`` is one decoded
    JSON line providing the keys ``goal``, ``sol1`` and ``sol2``; ``label`` is
    the matching stripped line of the label file, where ``"0"`` selects
    ``sol1`` and ``"1"`` selects ``sol2``.
    """

    def download(self):
        """Fetch the PIQA files into ``data/piqa`` unless that directory exists."""
        if not os.path.exists('data/piqa'):
            #TODO: use best_download
            sh("""
            mkdir -p data/piqa
            wget https://yonatanbisk.com/piqa/data/train.jsonl -O data/piqa/piqa-train.jsonl
            wget https://yonatanbisk.com/piqa/data/train-labels.lst -O data/piqa/piqa-train-labels.lst
            wget https://yonatanbisk.com/piqa/data/valid.jsonl -O data/piqa/piqa-valid.jsonl
            wget https://yonatanbisk.com/piqa/data/valid-labels.lst -O data/piqa/piqa-valid-labels.lst
            wget https://yonatanbisk.com/piqa/data/tests.jsonl -O data/piqa/piqa-test.jsonl
            """)

    def has_training_docs(self):
        """Return True: a labelled training split is available."""
        return True

    def has_validation_docs(self):
        """Return True: a labelled validation split is available."""
        return True

    def has_test_docs(self):
        """Return False: no label file is downloaded for the test split."""
        return False

    def load_docs(self, textfilename, labelfilename):
        """Read a JSON-lines file and, optionally, its parallel label file.

        Args:
            textfilename: Path to a file holding one JSON object per line.
            labelfilename: Path to a file holding one label per line, or
                ``None`` when the split has no labels.

        Returns:
            A list of decoded JSON objects when ``labelfilename`` is ``None``;
            otherwise a ``zip`` iterator of ``(entry, label)`` pairs, where
            each label is the stripped text of its line.
        """
        # Both files are read eagerly inside ``with`` blocks so the handles
        # are closed before returning instead of being left to the GC.
        with open(textfilename, 'r', encoding='utf-8') as text_file:
            entries = [json.loads(line) for line in text_file]

        if labelfilename is None:
            return entries

        with open(labelfilename, 'r', encoding='utf-8') as label_file:
            labels = [line.strip() for line in label_file]

        # Kept as a lazy ``zip`` (not a list) so the return type is unchanged.
        return zip(entries, labels)

    def training_docs(self):
        """Return the ``(entry, label)`` pairs of the training split."""
        return self.load_docs('data/piqa/piqa-train.jsonl', 'data/piqa/piqa-train-labels.lst')

    def validation_docs(self):
        """Return the ``(entry, label)`` pairs of the validation split."""
        return self.load_docs('data/piqa/piqa-valid.jsonl', 'data/piqa/piqa-valid-labels.lst')

    # Disabled together with has_test_docs() returning False.
    #def test_docs(self):
    #    return self.load_docs('data/piqa/piqa-test.jsonl', None)

    def fewshot_description(self):
        """Return the few-shot task description (currently empty)."""
        # TODO: figure out fewshot description
        return ""

    def doc_to_text(self, doc):
        """Return the prompt text for ``doc``: the goal of its entry."""
        return doc[0]['goal']

    def doc_to_target(self, doc):
        """Return the target text: newline, goal, a space, then the gold solution."""
        #TODO: check if oa uses newline
        # NOTE(review): the goal is repeated here although doc_to_text already
        # emits it, while construct_requests scores only the bare solution;
        # confirm this is the intended prompt format.
        rightanswer = int(doc[1]) + 1  # label 0/1 -> key suffix 1/2
        return '\n' + ''.join([doc[0]['goal'], ' ', doc[0]['sol' + str(rightanswer)]])

    def construct_requests(self, doc, ctx):
        """Build one loglikelihood request per candidate solution.

        Args:
            doc: An ``(entry, label)`` pair.
            ctx: The context string passed to ``rf.loglikelihood``.

        Returns:
            A pair of request results, for ``sol1`` and ``sol2`` in that order.
        """
        # Only the first element of each rf.loglikelihood result is used.
        ll_1, _ = rf.loglikelihood(ctx, doc[0]['sol1'])
        ll_2, _ = rf.loglikelihood(ctx, doc[0]['sol2'])

        return ll_1, ll_2

    def process_results(self, doc, results):
        """Score one document from the two values built by construct_requests.

        Returns:
            ``{'acc': bool}``. For label 0 the prediction is correct when
            ``ll_1 > ll_2``; for any other label it is correct when
            ``ll_1 <= ll_2`` (ties therefore favour ``sol2``).
        """
        ll_1, ll_2 = results
        gold_is_first = int(doc[1]) == 0

        return {
            'acc': (ll_1 > ll_2) == gold_is_first
        }

    def aggregation(self):
        """Map each metric name to the function that aggregates it."""
        return {
            'acc': mean
        }

    def higher_is_better(self):
        """Map each metric name to whether a larger value is better."""
        return {
            'acc': True
        }