"""
HellaSwag: Can a Machine Really Finish Your Sentence?
https://arxiv.org/pdf/1905.07830.pdf

HellaSwag is a commonsense inference challenge dataset. Though its questions are
trivial for humans (>95% accuracy), state-of-the-art models struggle (<48%). This is
achieved via Adversarial Filtering (AF), a data collection paradigm wherein a
series of discriminators iteratively select an adversarial set of machine-generated
wrong answers. AF proves to be surprisingly robust. The key insight is to scale up
the length and complexity of the dataset examples towards a critical 'Goldilocks'
zone wherein generated text is ridiculous to humans, yet often misclassified by
state-of-the-art models.

Homepage: https://rowanzellers.com/hellaswag/
"""
import re
from lm_eval.base import MultipleChoiceTask
from .common import HFTask


_CITATION = """
@inproceedings{zellers2019hellaswag,
    title={HellaSwag: Can a Machine Really Finish Your Sentence?},
    author={Zellers, Rowan and Holtzman, Ari and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin},
    booktitle={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
    year={2019}
}
"""


class HellaSwag(HFTask, MultipleChoiceTask):
    VERSION = 0
    DATASET_PATH = "hellaswag"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    @classmethod
    def preprocess(cls, text):
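        """Clean a HellaSwag string before it is used in a prompt.

        " [title]" is turned into a sentence break, every remaining bracketed
        tag (for example "[step]") is dropped, and each pair of adjacent
        spaces left behind is replaced by a single space. The input below is
        an illustrative string, not an actual dataset example:

        >>> HellaSwag.preprocess("How to make tea [title] Boil water. [step] Pour it.")
        'How to make tea. Boil water. Pour it.'
        """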
        text = text.strip()
        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
        text = text.replace(" [title]", ". ")
        text = re.sub(r"\[.*?\]", "", text)
        text = text.replace("  ", " ")
        return text

    def _convert_standard(self, doc):
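        """Convert a raw Hugging Face ``hellaswag`` record into the
        ``query``/``choices``/``gold`` layout that ``doc_to_text`` reads.

        ``ctx_a`` and the capitalized ``ctx_b`` are joined into the context,
        ``activity_label`` is prepended as a "label: " prefix, each entry of
        ``endings`` is cleaned with ``preprocess`` to become a choice, and
        ``label`` is cast to an int index into ``choices``. For an
        illustrative (hypothetical) record with activity_label "Making tea",
        ctx_a "A man fills a kettle.", ctx_b "he" and label "0", the result
        has query "Making tea: A man fills a kettle. He" and gold 0.
        """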
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        out_doc = {
            "query": self.preprocess(doc['activity_label'] + ': ' + ctx),
            "choices": [self.preprocess(ending) for ending in doc['endings']],
            "gold": int(doc['label']),
        }
        return out_doc

    def doc_to_text(self, doc):
        return doc["query"]