lambada.py 3.53 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
The LAMBADA dataset: Word prediction requiring a broad discourse context
https://arxiv.org/pdf/1606.06031.pdf

LAMBADA is a dataset to evaluate the capabilities of computational models for text
understanding by means of a word prediction task. LAMBADA is a collection of narrative
passages sharing the characteristic that human subjects are able to guess their last
word if they are exposed to the whole passage, but not if they only see the last
sentence preceding the target word. To succeed on LAMBADA, computational models
cannot simply rely on local context, but must be able to keep track of information
in the broader discourse.

Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
"""
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import mean, perplexity

from lm_eval.api.register import register_task, register_group

# BibTeX entry for the Zenodo release of the LAMBADA dataset.
# NOTE(review): the entry has no citation key after `@misc{` -- confirm whether
# any consumer of this string needs one before changing the literal.
_CITATION = """
@misc{
    author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel},
    title={The LAMBADA dataset},
    DOI={10.5281/zenodo.2630551},
    publisher={Zenodo},
    year={2016},
    month={Aug}
}
"""


class LambadaBase(Task):
    """Shared implementation of the LAMBADA last-word prediction task.

    Each document's ``"text"`` field is cut at its final space: the part
    before it is the context and the remaining word, prefixed with a single
    space, is the continuation whose log-likelihood is requested.

    This class calls ``has_training_docs``/``has_validation_docs``/
    ``has_test_docs`` and reads ``self.dataset`` but defines neither; the
    concrete task classes in this file supply the ``has_*_docs`` predicates
    together with ``DATASET_PATH`` and a concrete ``VERSION``.
    """

    VERSION = None

    OUTPUT_TYPE = "loglikelihood"

    @staticmethod
    def _split_last_word(text):
        """Return ``(context, last_word)`` by cutting ``text`` at its last space.

        ``str.rpartition`` is used so that context and target always come from
        the same split. A passage containing no space yields an empty context
        and the whole passage as the word, instead of raising ``IndexError``
        in one accessor while the other returns the full passage.
        """
        context, _, last_word = text.rpartition(" ")
        return context, last_word

    def training_docs(self):
        """Return the ``"train"`` split, or ``None`` if the task has none."""
        if self.has_training_docs():
            return self.dataset["train"]
        return None

    def validation_docs(self):
        """Return the ``"validation"`` split, or ``None`` if the task has none."""
        if self.has_validation_docs():
            return self.dataset["validation"]
        return None

    def test_docs(self):
        """Return the ``"test"`` split, or ``None`` if the task has none."""
        if self.has_test_docs():
            return self.dataset["test"]
        return None

    def doc_to_text(self, doc):
        """Return the passage with its final word removed."""
        context, _ = self._split_last_word(doc["text"])
        return context

    def should_decontaminate(self):
        """Report that this task takes part in decontamination."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Return the full, unsplit passage as the decontamination query."""
        return doc["text"]

    def doc_to_target(self, doc):
        """Return the final word of the passage, prefixed with one space."""
        _, last_word = self._split_last_word(doc["text"])
        return " " + last_word

    def construct_requests(self, doc, ctx, **kwargs):
        """Build the request that scores the target word given ``ctx``.

        The request type is ``OUTPUT_TYPE`` and its arguments are the pair
        ``(ctx, target)``. Extra keyword arguments are forwarded unchanged to
        ``Instance``.
        """
        return Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(ctx, self.doc_to_target(doc)),
            **kwargs,
        )

    def process_results(self, doc, results):
        """Turn the responses for ``doc`` into per-document metric values.

        Returns a dict with ``"ppl"`` set to the log-likelihood ``ll`` and
        ``"acc"`` set to ``int(is_greedy)``; ``aggregation`` names the
        functions that combine these values.
        """
        # TODO: recheck this. Currently a list of [(ll, is_greedy)] is passed
        # in and only its first entry is scored, which is a hack: filters
        # should make it so that there is one response per request to score.
        ll, is_greedy = results[0]

        return {"ppl": ll, "acc": int(is_greedy)}

    def aggregation(self):
        """Map each metric name to the function that aggregates it."""
        return {"ppl": perplexity, "acc": mean}

    def higher_is_better(self):
        """Map each metric name to whether larger values are better."""
        return {"ppl": False, "acc": True}


@register_task("lambada_standard")
class LambadaStandard(LambadaBase):
    """LAMBADA evaluated on the standard, original release of the dataset."""

    DATASET_PATH = "lambada"
    VERSION = "2.0"

    # Split availability flags: validation and test are reported as present,
    # train is not.

    def has_test_docs(self):
        """Report that a test split is available."""
        return True

    def has_validation_docs(self):
        """Report that a validation split is available."""
        return True

    def has_training_docs(self):
        """Report that no training split is available."""
        return False


@register_task("lambada_openai")
class LambadaOpenAI(LambadaBase):
    """LAMBADA evaluated on the OpenAI variant of the dataset.

    This variant is a modified version of the original LAMBADA dataset that
    OpenAI created for evaluating their GPT-2 model.

    Reference: https://github.com/openai/gpt-2/issues/131#issuecomment-497136199
    """

    DATASET_PATH = "EleutherAI/lambada_openai"
    VERSION = "2.0"

    # Split availability flags: only the test split is reported as present.

    def has_test_docs(self):
        """Report that a test split is available."""
        return True

    def has_validation_docs(self):
        """Report that no validation split is available."""
        return False

    def has_training_docs(self):
        """Report that no training split is available."""
        return False