lambada.py 3.09 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
"""
The LAMBADA dataset: Word prediction requiring a broad discourse context
https://arxiv.org/pdf/1606.06031.pdf

LAMBADA is a dataset to evaluate the capabilities of computational models for text
understanding by means of a word prediction task. LAMBADA is a collection of narrative
passages sharing the characteristic that human subjects are able to guess their last
word if they are exposed to the whole passage, but not if they only see the last
sentence preceding the target word. To succeed on LAMBADA, computational models
cannot simply rely on local context, but must be able to keep track of information
in the broader discourse.

Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
15
import json
&'s avatar
& committed
16
17
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity
sdtblck's avatar
sdtblck committed
18
from lm_eval.utils import sh
Leo Gao's avatar
Leo Gao committed
19
from best_download import download_file
20
import os
sdtblck's avatar
sdtblck committed
21
22


23
24
25
26
27
28
29
30
31
32
33
34
_CITATION = """
@misc{
    author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel}, 
    title={The LAMBADA dataset},
    DOI={10.5281/zenodo.2630551},
    publisher={Zenodo},
    year={2016},
    month={Aug}
}
"""


35
class LAMBADA(Task):
    """LAMBADA last-word prediction task.

    Each document is one JSON object per line of
    ``data/lambada/lambada_test.jsonl`` with a ``"text"`` field. The context
    is the text up to its last space, and the target is the final
    space-separated token (with a leading space). Results are the per-document
    log-likelihood of the target (aggregated as perplexity, ``ppl``) and
    whether the target is the greedy continuation (aggregated as mean, ``acc``).
    """

    VERSION = 0

    def download(self):
        """Fetch the LAMBADA test file into ``data/lambada`` if it is missing.

        Tries ``download_file`` (with a SHA-256 checksum) first. If that
        raises, it falls back to ``wget`` followed by ``sha256sum --check``
        through the shell.
        """
        os.makedirs("data/lambada", exist_ok=True)
        if os.path.exists("data/lambada/lambada_test.jsonl"):
            return
        try:
            download_file(
                "http://eaidata.bmk.sh/data/lambada_test.jsonl",
                local_file="data/lambada/lambada_test.jsonl",
                expected_checksum="4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226"
            )
        except Exception:
            # Fallback: best_download does not work reliably here.
            # Catch ``Exception`` (not a bare ``except``) so Ctrl-C
            # (KeyboardInterrupt) and SystemExit still propagate.
            sh("wget http://eaidata.bmk.sh/data/lambada_test.jsonl -O data/lambada/lambada_test.jsonl")
            sh('echo "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226  data/lambada/lambada_test.jsonl" | sha256sum --check')

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        # The published test split is served through validation_docs().
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        pass

    def validation_docs(self):
        """Yield one parsed JSON dict per non-blank line of the test file."""
        # Explicit UTF-8: the default encoding is locale-dependent.
        with open("data/lambada/lambada_test.jsonl", encoding="utf-8") as fh:
            for line in fh:
                # Skip blank lines, which json.loads would reject.
                if line.strip():
                    yield json.loads(line)

    def test_docs(self):
        pass

    def doc_to_text(self, doc):
        """Return the context: ``doc['text']`` up to (not including) its last space."""
        return doc['text'].rsplit(' ', 1)[0]

    def doc_to_target(self, doc):
        """Return the last space-separated word of ``doc['text']``, prefixed by a space.

        Raises IndexError if the text contains no space.
        """
        return " " + doc['text'].rsplit(' ', 1)[1]

    def construct_requests(self, doc, ctx):
        """Request the log-likelihood of the target word given context ``ctx``.

        The request unpacks into the log-likelihood and a flag for whether the
        target is the greedy continuation.
        """
        ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))

        return ll, is_greedy

    def process_results(self, doc, results):
        """Map the (log-likelihood, is_greedy) results to per-document metrics."""
        ll, is_greedy = results

        return {
            'ppl': ll,
            'acc': int(is_greedy)
        }

    def aggregation(self):
        """Return the function that aggregates each metric across documents."""
        return {
            'ppl': perplexity,
            'acc': mean
        }

    def higher_is_better(self):
        """Return, for each metric, whether a larger value is better."""
        return {
            'ppl': False,
            'acc': True
        }