import json
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from best_download import download_file
import os


class LAMBADA(Task):
    """Last-word prediction task over the LAMBADA test file.

    Each document is a JSON object with a ``'text'`` field. The context is
    everything before the final space of that text and the target is the
    final space-separated word (with a leading space). Only a validation
    split is exposed; it is read from the downloaded ``lambada_test.jsonl``.
    """

    VERSION = 0

    # Location, source URL and expected SHA-256 of the evaluation file. Kept
    # in one place so the primary and fallback download paths cannot drift.
    _DATA_DIR = "data/lambada"
    _DATA_PATH = "data/lambada/lambada_test.jsonl"
    _DATA_URL = "http://eaidata.bmk.sh/data/lambada_test.jsonl"
    _DATA_SHA256 = "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"

    def download(self):
        """Fetch the evaluation file into ``data/lambada`` if it is missing.

        ``download_file`` is tried first; if it raises, the file is fetched
        with ``wget`` and checked with ``sha256sum`` instead.
        """
        os.makedirs(self._DATA_DIR, exist_ok=True)
        try:
            if not os.path.exists(self._DATA_PATH):
                download_file(
                    self._DATA_URL,
                    local_file=self._DATA_PATH,
                    expected_checksum=self._DATA_SHA256,
                )
        except Exception:
            # Fallback - for some reason best_download doesn't work all the
            # time here. Only ``Exception`` is caught so that Ctrl-C and
            # interpreter shutdown are not turned into a second download.
            sh("wget {} -O {}".format(self._DATA_URL, self._DATA_PATH))
            # NOTE(review): presumably ``sh`` raises when the checksum
            # command fails - confirm against lm_eval.utils.
            sh(
                'echo "{}  {}" | sha256sum --check'.format(
                    self._DATA_SHA256, self._DATA_PATH
                )
            )

    def has_training_docs(self):
        """Return False: this task exposes no training split."""
        return False

    def has_validation_docs(self):
        """Return True: the downloaded file is served as the validation split."""
        return True

    def has_test_docs(self):
        """Return False: this task exposes no test split."""
        return False

    def training_docs(self):
        """Return None; there are no training documents."""
        pass

    def validation_docs(self):
        """Yield one parsed JSON document per non-blank line of the data file."""
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(self._DATA_PATH, encoding="utf-8") as fh:
            for line in fh:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if line.strip():
                    yield json.loads(line)

    def test_docs(self):
        """Return None; there are no test documents."""
        pass

    def doc_to_text(self, doc):
        """Return the context: ``doc['text']`` up to (not including) its last space."""
        return doc['text'].rsplit(' ', 1)[0]

    def should_decontaminate(self):
        """Return True: this task supplies a decontamination query."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Return the full ``doc['text']`` as the decontamination query."""
        return doc['text']

    def doc_to_target(self, doc):
        """Return the target: a space followed by the last word of ``doc['text']``.

        Raises ``IndexError`` if the text contains no space.
        """
        return " " + doc['text'].rsplit(' ', 1)[1]

    def construct_requests(self, doc, ctx):
        """Build the log-likelihood request for the target word given ``ctx``.

        Returns the two values produced by ``rf.loglikelihood`` (named
        ``ll`` and ``is_greedy`` here), in that order.
        """
        ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
        return ll, is_greedy

    def process_results(self, doc, results):
        """Map the ``(ll, is_greedy)`` results of one document to its metrics.

        ``'ppl'`` carries the raw ``ll`` value (it is combined by the
        ``perplexity`` aggregation); ``'acc'`` is ``is_greedy`` as 0 or 1.
        """
        ll, is_greedy = results
        return {
            'ppl': ll,
            'acc': int(is_greedy),
        }

    def aggregation(self):
        """Return the aggregation function used for each metric."""
        return {
            'ppl': perplexity,
            'acc': mean,
        }

    def higher_is_better(self):
        """Return, per metric, whether a larger value is better."""
        return {
            'ppl': False,
            'acc': True,
        }