"""
TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension
https://arxiv.org/pdf/1705.03551.pdf

TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high quality distant supervision for answering the questions.

Homepage: https://nlp.cs.washington.edu/triviaqa/
"""
import inspect
# import lm_eval.datasets.triviaqa.triviaqa
import string
from lm_eval.api.task import Task
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_task
from lm_eval.api.metrics import mean

# BibTeX entry for the TriviaQA paper (Joshi et al., ACL 2017), kept as a
# module-level constant so the harness can report the citation for this task.
_CITATION = """
@InProceedings{JoshiTriviaQA2017,
    author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
    title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
    booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
    month = {July},
    year = {2017},
    address = {Vancouver, Canada},
    publisher = {Association for Computational Linguistics},
}
"""

@register_task("triviaqa")
class TriviaQA(Task):
    """TriviaQA (unfiltered, no-context) open-domain QA task.

    The model is prompted with ``Q: <question>\nA:`` and generates greedily
    until a stop sequence; the generation is scored by exact match (after
    lowercasing and punctuation stripping) against the answer's alias list.
    """

    VERSION = 1
    DATASET_PATH = "trivia_qa"  # inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
    DATASET_NAME = "unfiltered.nocontext"

    OUTPUT_TYPE = "greedy_until"

    # Built once at class-creation time: process_results previously rebuilt
    # this translation table for every alias of every document.
    _PUNCT_TABLE = str.maketrans("", "", string.punctuation)

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        # No labeled test split is exposed for this configuration.
        raise NotImplementedError()

    def doc_to_text(self, doc):
        """Format a document as the question prompt shown to the model."""
        return f"Q: {doc['question']}\nA:"

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        # Decontamination is matched on the raw question text.
        return doc["question"]

    def doc_to_target(self, doc):
        # Leading space so the target concatenates cleanly after "A:".
        return " " + doc["answer"]["value"]

    def _remove_prefixes(self, aliases):
        """Drop any alias that strictly extends another alias in the list.

        Optimization: if a prefix is acceptable to a greedy matcher we can
        stop looking, so longer aliases sharing that prefix are redundant.
        Works on a sorted copy — the original used in-place ``aliases.sort()``,
        which mutated the caller's ``doc["answer"]["aliases"]`` list.

        :param aliases: list of answer strings (may be empty).
        :return: new sorted list with prefix-extending aliases removed.
        """
        if not aliases:
            # Guard: the original raised IndexError on an empty alias list.
            return []
        ordered = sorted(aliases)
        ret = [ordered[0]]
        for alias in ordered[1:]:
            if not alias.startswith(ret[-1]):
                ret.append(alias)
        return ret

    def construct_requests(self, doc, ctx, **kwargs):
        """Build the single greedy-generation request for this document.

        :param doc:
            The document as returned from training_docs, validation_docs, or
            test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes
            the natural language description, the few-shot examples, and the
            question part of the document for `doc`.
        :return: an `Instance` whose generation stops at newline, period, or
            comma and uses deterministic (non-sampled) decoding.
        """
        continuation = Instance(
            request_type=self.OUTPUT_TYPE,
            doc=doc,
            arguments=(
                ctx,
                {
                    "until": ["\n", ".", ","],
                    "do_sample": False,
                },
            ),
            idx=0,
            **kwargs,
        )
        return continuation

    def process_results(self, doc, results):
        """Score the generated continuation against the answer aliases.

        :param doc: the source document (provides ``answer.aliases``).
        :param results: list whose first element is the model's generation.
        :return: ``{"em": 1.0}`` if the normalized generation exactly matches
            any normalized alias, else ``{"em": 0.0}``.
        """
        table = self._PUNCT_TABLE
        continuation = results[0].strip().lower().translate(table)
        list_of_candidates = [
            alias.lower().translate(table)
            for alias in self._remove_prefixes(doc["answer"]["aliases"])
        ]
        return {"em": float(continuation in list_of_candidates)}

    def aggregation(self):
        # Exact match is averaged over documents.
        return {
            "em": mean,
        }

    def higher_is_better(self):
        return {"em": True}