"""
The LAMBADA dataset: Word prediction requiring a broad discourse context
https://arxiv.org/pdf/1606.06031.pdf

LAMBADA is a dataset to evaluate the capabilities of computational models for text
understanding by means of a word prediction task. LAMBADA is a collection of narrative
passages sharing the characteristic that human subjects are able to guess their last
word if they are exposed to the whole passage, but not if they only see the last
sentence preceding the target word. To succeed on LAMBADA, computational models
cannot simply rely on local context, but must be able to keep track of information
in the broader discourse.

Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
from lm_eval.api.task import Task
haileyschoelkopf's avatar
haileyschoelkopf committed
16
from lm_eval.api.instance import Instance
17
from lm_eval.api.metrics import mean, perplexity
sdtblck's avatar
sdtblck committed
18

19
from lm_eval import utils
sdtblck's avatar
sdtblck committed
20

21
22
_CITATION = """
@misc{
Fabrizio Milo's avatar
Fabrizio Milo committed
23
    author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel},
24
25
26
27
28
29
30
    title={The LAMBADA dataset},
    DOI={10.5281/zenodo.2630551},
    publisher={Zenodo},
    year={2016},
    month={Aug}
}
"""
sdtblck's avatar
sdtblck committed
31
32


jon-tow's avatar
jon-tow committed
33
34
class LambadaBase(Task):
    """Shared logic for LAMBADA last-word prediction tasks.

    Each document's ``"text"`` field is split at its final space. Everything
    before that space is the context. The final word, with a leading space
    restored, is the continuation whose log-likelihood is scored. Concrete
    subclasses supply ``VERSION``, ``TASK_NAME`` and ``DATASET_PATH`` and
    declare which dataset splits exist via the ``has_*_docs`` methods.
    """

    VERSION = None

    OUTPUT_TYPE = "loglikelihood"

    def training_docs(self):
        """Return the ``"train"`` split of ``self.dataset``, or ``None`` if absent."""
        if self.has_training_docs():
            return self.dataset["train"]
        return None

    def validation_docs(self):
        """Return the ``"validation"`` split of ``self.dataset``, or ``None`` if absent."""
        if self.has_validation_docs():
            return self.dataset["validation"]
        return None

    def test_docs(self):
        """Return the ``"test"`` split of ``self.dataset``, or ``None`` if absent."""
        if self.has_test_docs():
            return self.dataset["test"]
        return None

    def doc_to_text(self, doc):
        """Return the passage with its final space-separated word removed.

        If the text contains no space, ``rsplit`` yields a single piece and the
        whole text is returned unchanged.
        """
        return doc["text"].rsplit(" ", 1)[0]

    def should_decontaminate(self):
        """Opt this task into decontamination checks."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Use the full, unsplit passage as the decontamination query."""
        return doc["text"]

    def doc_to_target(self, doc):
        """Return the passage's final word, prefixed with the space it was split on.

        The leading space makes ``doc_to_text(doc) + doc_to_target(doc)``
        reproduce the original text.

        Raises:
            IndexError: if the text contains no space, so there is no final word
                to separate from a context.
        """
        return " " + doc["text"].rsplit(" ", 1)[1]

    def construct_requests(self, doc, ctx, **kwargs):
        """Build the single request scoring the target word after ``ctx``.

        Extra keyword arguments are forwarded unchanged to ``Instance``.
        """
        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(ctx, self.doc_to_target(doc)),
            **kwargs,
        )

    def process_results(self, doc, results):
        """Turn the model's response for one document into per-document metrics.

        ``results`` is a list whose first element is a ``(loglikelihood,
        is_greedy)`` pair. ``is_greedy`` presumably indicates whether the target
        matched the model's greedy continuation.

        Returns:
            dict with ``"ppl"``, the raw log-likelihood (aggregated by
            ``perplexity``), and ``"acc"``, ``int(is_greedy)`` (aggregated by
            ``mean``).
        """
        # TODO: filters should guarantee exactly one scored response per
        # request; until then, only the first response in the list is used.
        ll, is_greedy = results[0]
        return {"ppl": ll, "acc": int(is_greedy)}

    def aggregation(self):
        """Map each metric name to the function that aggregates it across documents."""
        return {"ppl": perplexity, "acc": mean}

    def higher_is_better(self):
        """Report the optimisation direction for each metric (lower perplexity is better)."""
        return {"ppl": False, "acc": True}


@utils.register_task
class LambadaStandard(LambadaBase):
    """The LAMBADA task using the standard original LAMBADA dataset."""

    VERSION = "2.0"
    TASK_NAME = "lambada_standard"
    DATASET_PATH = "lambada"

    def has_training_docs(self):
        """This task exposes no training split."""
        return False

    def has_validation_docs(self):
        """This task evaluates on the ``"validation"`` split."""
        return True

    def has_test_docs(self):
        """This task evaluates on the ``"test"`` split."""
        return True


@utils.register_task
class LambadaOpenAI(LambadaBase):
    """The LAMBADA task using the LAMBADA OpenAI dataset, a modified version of the
    original LAMBADA dataset created by OpenAI for evaluating their GPT-2 model.

    Reference: https://github.com/openai/gpt-2/issues/131#issuecomment-497136199
    """

    VERSION = "2.0"
    TASK_NAME = "lambada_openai"
    DATASET_PATH = "EleutherAI/lambada_openai"

    def has_training_docs(self):
        """This task exposes no training split."""
        return False

    def has_validation_docs(self):
        """This task exposes no validation split."""
        return False

    def has_test_docs(self):
        """Only the ``"test"`` split is used for evaluation."""
        return True