# lambada_cloze.py
import json
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
from lm_eval.utils import sh
from best_download import download_file


class LAMBADA_cloze(Task):
    """LAMBADA last-word prediction posed as a fill-in-the-blank (cloze) task.

    Each document's ``text`` is split at its final space: everything before
    that space, followed by the marker ``" ____. ->"``, forms the prompt, and
    the remaining token (prefixed with a space) is the continuation the model
    is scored on. Only validation documents are provided; they are read from
    the downloaded ``lambada_test.jsonl`` file.
    """

    # Local data layout, shared by download() and validation_docs() so the
    # download target and the file that is read cannot drift apart.
    _DATA_DIR = "data/lambada"
    _DATA_PATH = _DATA_DIR + "/lambada_test.jsonl"

    def download(self):
        """Create the data directory and fetch the LAMBADA JSONL file into it."""
        sh("mkdir -p " + self._DATA_DIR)
        download_file(
            "http://eaidata.bmk.sh/data/lambada_test.jsonl",
            self._DATA_PATH,
            # NOTE(review): presumably the expected SHA-256 of the file,
            # checked by download_file -- confirm against best_download.
            "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226",
        )

    def has_training_docs(self):
        """Return False: this task exposes no training split."""
        return False

    def has_validation_docs(self):
        """Return True: documents are served through validation_docs()."""
        return True

    def has_test_docs(self):
        """Return False: this task exposes no test split."""
        return False

    def training_docs(self):
        """Return None; there is no training split (see has_training_docs)."""
        return None

    def validation_docs(self):
        """Yield one parsed JSON object per non-blank line of the data file.

        Note that the file is named ``lambada_test.jsonl`` but is served as
        the validation split.
        """
        # Pin the encoding: without it, decoding depends on the platform's
        # locale and non-ASCII text can be corrupted or raise.
        with open(self._DATA_PATH, encoding="utf-8") as fh:
            for line in fh:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # letting json.loads fail on them.
                if not line.strip():
                    continue
                yield json.loads(line)

    def test_docs(self):
        """Return None; there is no test split (see has_test_docs)."""
        return None

    def doc_to_text(self, doc):
        """Return the prompt: the text up to its last space plus a blank marker.

        If ``doc['text']` contains no space, the whole text is used.
        """
        return doc['text'].rsplit(' ', 1)[0] + " ____. ->"

    def doc_to_target(self, doc):
        """Return the continuation: a space followed by the text's last token.

        Raises:
            IndexError: if ``doc['text']`` contains no space to split on.
        """
        return " " + doc['text'].rsplit(' ', 1)[1]

    def fewshot_description(self):
        """Return the instruction line describing the task."""
        return "Fill in blank:\n"

    def construct_requests(self, doc, ctx):
        """Request the log-likelihood of the target continuation given ``ctx``.

        Returns the pair produced by ``rf.loglikelihood``, unpacked here as
        ``(ll, is_greedy)``.
        """
        ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))

        return ll, is_greedy

    def process_results(self, doc, results):
        """Map the ``(ll, is_greedy)`` results to this task's per-document metrics.

        The raw log-likelihood is stored under ``'ppl'``; it is combined
        across documents by the aggregator named in aggregation().
        ``'acc'`` is ``is_greedy`` converted to 0 or 1.
        """
        ll, is_greedy = results

        return {
            'ppl': ll,
            'acc': int(is_greedy)
        }

    def aggregation(self):
        """Return the aggregation function used for each metric."""
        return {
            'ppl': perplexity,
            'acc': mean
        }

    def higher_is_better(self):
        """Return, per metric, whether a larger value is better."""
        return {
            'ppl': False,
            'acc': True
        }