"""
The LAMBADA dataset: Word prediction requiring a broad discourse context
https://arxiv.org/pdf/1606.06031.pdf

LAMBADA is a dataset to evaluate the capabilities of computational models for text
understanding by means of a word prediction task. LAMBADA is a collection of narrative
passages sharing the characteristic that human subjects are able to guess their last
word if they are exposed to the whole passage, but not if they only see the last
sentence preceding the target word. To succeed on LAMBADA, computational models
cannot simply rely on local context, but must be able to keep track of information
in the broader discourse.

Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
from lm_eval.base import Task, rf
from lm_eval.metrics import mean, perplexity


_CITATION = """
@misc{
bzantium's avatar
bzantium committed
21
    author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel},
22
23
24
25
26
27
28
29
30
    title={The LAMBADA dataset},
    DOI={10.5281/zenodo.2630551},
    publisher={Zenodo},
    year={2016},
    month={Aug}
}
"""


class LambadaBase(Task):
    """Shared logic for LAMBADA last-word prediction tasks.

    Each document's ``text`` field is split at its final space. Everything
    before that space is the context given to the model. The final word,
    with a leading space, is the continuation whose log-likelihood is
    scored. Subclasses pick the dataset (``DATASET_PATH``) and declare
    which splits exist via the ``has_*_docs`` methods.
    """

    # Unversioned base; concrete subclasses set their own VERSION.
    VERSION = None

    def training_docs(self):
        """Return the ``train`` split, or ``None`` if the task has none."""
        if self.has_training_docs():
            return self.dataset["train"]
        return None

    def validation_docs(self):
        """Return the ``validation`` split, or ``None`` if the task has none."""
        if self.has_validation_docs():
            return self.dataset["validation"]
        return None

    def test_docs(self):
        """Return the ``test`` split, or ``None`` if the task has none."""
        if self.has_test_docs():
            return self.dataset["test"]
        return None

    @staticmethod
    def _split_last_word(text):
        """Split *text* into ``(context, last_word)`` at its final space.

        For any text containing a space this matches ``text.rsplit(" ", 1)``.
        A text with no space gives an empty context and the whole text as
        the last word. The previous ``rsplit(...)[1]`` raised ``IndexError``
        in that case.
        """
        context, _, last_word = text.rpartition(" ")
        return context, last_word

    def doc_to_text(self, doc):
        """Return the passage with its final word removed (the model context)."""
        return self._split_last_word(doc["text"])[0]

    def should_decontaminate(self):
        """Enable decontamination checks for this task."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Use the full passage, including the target word, as the query."""
        return doc["text"]

    def doc_to_target(self, doc):
        """Return the final word of the passage, prefixed with a space.

        The leading space makes context + target reproduce the original
        passage.
        """
        return " " + self._split_last_word(doc["text"])[1]

    def construct_requests(self, doc, ctx):
        """Issue one log-likelihood request for the target given *ctx*.

        The request is unpacked into two results: the log-likelihood and a
        flag (presumably whether the target is the model's greedy
        continuation).
        """
        ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
        return ll, is_greedy

    def process_results(self, doc, results):
        """Map the request results to per-document metric values.

        ``ppl`` receives the raw log-likelihood, which ``aggregation``
        combines across documents. ``acc`` is the greedy flag as 0/1.
        """
        ll, is_greedy = results
        return {"ppl": ll, "acc": int(is_greedy)}

    def aggregation(self):
        """Return the aggregation function for each metric."""
        return {"ppl": perplexity, "acc": mean}

    def higher_is_better(self):
        """Lower perplexity is better; higher accuracy is better."""
        return {"ppl": False, "acc": True}


class LambadaStandard(LambadaBase):
    """LAMBADA scored on the original dataset release.

    Documents come from the ``validation`` and ``test`` splits; no
    training split is exposed.
    """

    DATASET_PATH = "lambada"
    VERSION = 0

    def has_test_docs(self):
        """Report that a ``test`` split is available."""
        return True

    def has_validation_docs(self):
        """Report that a ``validation`` split is available."""
        return True

    def has_training_docs(self):
        """Report that no ``train`` split is used."""
        return False


class LambadaOpenAI(LambadaBase):
    """LAMBADA scored on OpenAI's modified copy of the dataset.

    OpenAI created this variant of the original LAMBADA data to evaluate
    GPT-2. Only its ``test`` split is exposed.

    Reference: https://github.com/openai/gpt-2/issues/131#issuecomment-497136199
    """

    DATASET_PATH = "EleutherAI/lambada_openai"
    VERSION = 0

    def has_test_docs(self):
        """Report that a ``test`` split is available."""
        return True

    def has_validation_docs(self):
        """Report that no ``validation`` split is used."""
        return False

    def has_training_docs(self):
        """Report that no ``train`` split is used."""
        return False