triviaqa.py 3.43 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
"""
TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension
https://arxiv.org/pdf/1705.03551.pdf

TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high quality distant supervision for answering the questions.

Homepage: https://nlp.cs.washington.edu/triviaqa/
"""
Jonathan Tow's avatar
Jonathan Tow committed
12
13
import inspect
import lm_eval.datasets.triviaqa.triviaqa
seopbo's avatar
seopbo committed
14
import string
&'s avatar
& committed
15
from lm_eval.base import Task, rf
Jonathan Tow's avatar
Jonathan Tow committed
16
from lm_eval.metrics import mean
Anish Thite's avatar
Anish Thite committed
17

18
19
20
21
22
23
24
25
26
27
28
# BibTeX entry for the TriviaQA paper (Joshi et al., ACL 2017).
_CITATION = """
@InProceedings{JoshiTriviaQA2017,
    author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
    title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
    booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
    month = {July},
    year = {2017},
    address = {Vancouver, Canada},
    publisher = {Association for Computational Linguistics},
}
"""
Anish Thite's avatar
Anish Thite committed
29

30

31
class TriviaQA(Task):
    """Task definition for the TriviaQA dataset described in the module docstring."""

    # NOTE(review): presumably the task/metric revision tracked by the harness;
    # confirm against the Task base class before relying on it.
    VERSION = 1
    # Filesystem path of the file that defines the bundled
    # `lm_eval.datasets.triviaqa.triviaqa` module.
    DATASET_PATH = inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
    # No dataset name is given; how Task interprets None is not shown here.
    DATASET_NAME = None
Anish Thite's avatar
Anish Thite committed
35
36
37
38
39
40
41
42

    def has_training_docs(self):
        """Report that this task exposes a training split."""
        available = True
        return available

    def has_validation_docs(self):
        """Report that this task exposes a validation split."""
        available = True
        return available

    def has_test_docs(self):
        """Report that this task does not expose a test split.

        Consistent with `test_docs`, which raises `NotImplementedError`.
        """
        return False
Anish Thite's avatar
Anish Thite committed
44
45

    def training_docs(self):
        """Return the "train" split of `self.dataset`."""
        return self.dataset["train"]
Anish Thite's avatar
Anish Thite committed
47
48

    def validation_docs(self):
        """Return the "validation" split of `self.dataset`."""
        return self.dataset["validation"]
Anish Thite's avatar
Anish Thite committed
50
51

    def test_docs(self):
52
        raise NotImplementedError()
Jonathan Tow's avatar
Jonathan Tow committed
53

54
    def doc_to_text(self, doc):
        """Build the prompt for one document.

        :param doc: mapping with a "question" entry.
        :return: the question followed by a trailing "Answer:" cue. No space
            follows the colon; `doc_to_target` supplies the leading space.
        """
        return f"Question: {doc['question']}\nAnswer:"
56

57
58
59
60
    def should_decontaminate(self):
        """Report that decontamination is enabled for this task."""
        enabled = True
        return enabled

    def doc_to_decontamination_query(self, doc):
        """Return the text used as the decontamination query: the raw question.

        :param doc: mapping with a "question" entry.
        """
        return doc["question"]
62

63
    def doc_to_target(self, doc):
        """Return the reference answer, prefixed with a single space.

        :param doc: mapping whose "answer" entry holds a "value" string.
        :return: " " + the answer value, so it can directly follow the
            "Answer:" cue produced by `doc_to_text`.
        """
        return " " + doc["answer"]["value"]
Leo Gao's avatar
Leo Gao committed
65
66
67
68
69
70
71
72
73
74

    def _remove_prefixes(self, aliases):
        # Optimization: Remove any alias that has a strict prefix elsewhere in the list
        # we can do this because if the prefix is acceptable by isgreedy, we can stop looking
        aliases.sort()
        ret = [aliases[0]]
        for alias in aliases[1:]:
            if not alias.startswith(ret[-1]):
                ret.append(alias)
        return ret
Anish Thite's avatar
Anish Thite committed
75

Leo Gao's avatar
Leo Gao committed
76
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct the Request which will be sent to the LM.

        :param doc:
                The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
                The context string, generated by fewshot_context. This includes the natural
                language description, as well as the few shot examples, and the question
                part of the document for `doc`.
        :return:
                The single request object built by `rf.greedy_until`.
        """
        # Stop sequences passed to the request: a newline, a period or a comma.
        continuation = rf.greedy_until(ctx, {"until": ["\n", ".", ","]})
        return continuation
88

Leo Gao's avatar
Leo Gao committed
89
    def process_results(self, doc, results):
        """Score one generated answer by exact match against the answer aliases.

        The prediction and every alias are normalized the same way:
        lower-cased, ASCII punctuation removed, then surrounding whitespace
        stripped. The prediction is correct when its normalized form equals
        any normalized alias.

        All aliases are considered. Pruning aliases that extend a shorter
        alias is only valid for prefix-style checks; with exact matching it
        would reject a valid longer alias (e.g. "New York City" when
        "New York" is also listed).

        :param doc: mapping whose "answer" entry holds an "aliases" sequence
            of strings; it is not modified.
        :param results: sequence whose first item is the generated text.
        :return: {"em": 1.0} on a match, {"em": 0.0} otherwise.
        """
        # Build the punctuation-stripping table once for all strings.
        strip_punct = str.maketrans("", "", string.punctuation)

        def normalize(text):
            # Strip last, so whitespace exposed by removed punctuation
            # ("Paris !" -> "paris ") cannot cause a spurious mismatch.
            return text.lower().translate(strip_punct).strip()

        prediction = normalize(results[0])
        candidates = {normalize(alias) for alias in doc["answer"]["aliases"]}
        return {"em": float(prediction in candidates)}
Leo Gao's avatar
Leo Gao committed
93
94

    def aggregation(self):
        """Map each metric name to the function that aggregates it.

        :return: dict pairing the "em" metric with `mean`.
        """
        return {
            "em": mean,
        }
Leo Gao's avatar
Leo Gao committed
98
99

    def higher_is_better(self):
        """Map each metric name to whether a larger value is better.

        :return: dict marking "em" as higher-is-better.
        """
        return {"em": True}