# hellaswag.py
"""
HellaSwag: Can a Machine Really Finish Your Sentence?
https://arxiv.org/pdf/1905.07830.pdf

HellaSwag is a commonsense inference challenge dataset. Though its questions are
trivial for humans (>95% accuracy), state-of-the-art models struggle (<48%). This is
achieved via Adversarial Filtering (AF), a data collection paradigm wherein a
series of discriminators iteratively select an adversarial set of machine-generated
wrong answers. AF proves to be surprisingly robust. The key insight is to scale up
the length and complexity of the dataset examples towards a critical 'Goldilocks'
zone wherein generated text is ridiculous to humans, yet often misclassified by
state-of-the-art models.

Homepage: https://rowanzellers.com/hellaswag/
"""
import re
from lm_eval.base import MultipleChoiceTask

# BibTeX entry for the HellaSwag paper (Zellers et al., ACL 2019).
_CITATION = """
@inproceedings{zellers2019hellaswag,
    title={HellaSwag: Can a Machine Really Finish Your Sentence?},
    author={Zellers, Rowan and Holtzman, Ari and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin},
    booktitle ={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
    year={2019}
}
"""

class HellaSwag(MultipleChoiceTask):
    """HellaSwag commonsense sentence-completion task.

    Raw records are converted into the standard multiple-choice form: a dict
    with ``query`` (the prompt), ``choices`` (the candidate endings) and
    ``gold`` (the integer index of the correct ending). Train and validation
    splits are exposed; no test split is exposed.
    """

    VERSION = 0
    # Dataset identifiers, presumably consumed by the base task to populate
    # ``self.dataset`` (e.g. via HuggingFace ``datasets``) -- confirm in
    # lm_eval.base.
    DATASET_PATH = "hellaswag"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        """Return the converted training docs, building and caching them once.

        The cache holds *converted* docs -- the same objects this method
        returns -- so repeated calls never re-run preprocessing. Any other
        reader of ``self._training_docs`` (presumably the base class's few-shot
        sampler; confirm in lm_eval.base) therefore sees
        ``query``/``choices``/``gold`` dicts rather than raw records.
        """
        if self._training_docs is None:
            self._training_docs = [
                self._convert_standard(doc) for doc in self.dataset["train"]
            ]
        return self._training_docs

    def validation_docs(self):
        """Return a lazy iterator of converted validation docs."""
        return map(self._convert_standard, self.dataset["validation"])

    def _convert_standard(self, doc):
        """Convert one raw HellaSwag record to the multiple-choice format.

        Reads ``ctx_a``, ``ctx_b``, ``activity_label``, ``endings`` and
        ``label`` from ``doc``. The query is
        ``"<activity_label>: <ctx_a> <Ctx_b>"`` passed through
        :meth:`preprocess`; every ending is preprocessed the same way.
        """
        # NOTE: str.capitalize() also lower-cases every character after the
        # first, so capitals inside ctx_b are lost. Kept as-is because changing
        # the prompt text would change evaluation results.
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        return {
            "query": self.preprocess(doc["activity_label"] + ": " + ctx),
            "choices": [self.preprocess(ending) for ending in doc["endings"]],
            # int() normalizes the label; presumably it may be stored as a
            # string -- confirm against the dataset schema.
            "gold": int(doc["label"]),
        }

    @classmethod
    def preprocess(cls, text):
        """Clean a HellaSwag text fragment.

        Strips surrounding whitespace, turns `` [title]`` markers into
        sentence breaks, drops any other ``[...]`` tokens, and replaces
        double spaces with single ones.
        """
        text = text.strip()
        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
        text = text.replace(" [title]", ". ")
        # Non-greedy so each bracketed token is removed on its own, rather than
        # everything between the first "[" and the last "]".
        text = re.sub(r"\[.*?\]", "", text)
        # Single non-overlapping pass: runs of three or more spaces are only
        # partially collapsed. Kept as-is so prompts stay unchanged.
        text = text.replace("  ", " ")
        return text

    def doc_to_text(self, doc):
        """Return the prompt text for a converted doc (its ``query`` field)."""
        return doc["query"]