triviaqa.py 3.46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
"""
TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension
https://arxiv.org/pdf/1705.03551.pdf

TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high quality distant supervision for answering the questions.

Homepage: https://nlp.cs.washington.edu/triviaqa/
"""
Jonathan Tow's avatar
Jonathan Tow committed
12
import inspect
seopbo's avatar
seopbo committed
13
import string
&'s avatar
& committed
14
from lm_eval.base import Task, rf
Jonathan Tow's avatar
Jonathan Tow committed
15
from lm_eval.metrics import mean
Anish Thite's avatar
Anish Thite committed
16

17
18
19
20
21
22
23
24
25
26
27
# BibTeX entry for the TriviaQA paper (Joshi et al., ACL 2017).
_CITATION = """
@InProceedings{JoshiTriviaQA2017,
    author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
    title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
    booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
    month = {July},
    year = {2017},
    address = {Vancouver, Canada},
    publisher = {Association for Computational Linguistics},
}
"""
Anish Thite's avatar
Anish Thite committed
28

29

30
class TriviaQA(Task):
    """TriviaQA question answering task (the ``rc.nocontext`` variant).

    The prompt built by :meth:`doc_to_text` contains only the question.
    A generation is scored as correct when any accepted answer alias occurs
    in it as a substring, after both sides are lower-cased and stripped of
    ASCII punctuation (see :meth:`process_results`).
    """

    VERSION = 3
    DATASET_PATH = "trivia_qa"
    DATASET_NAME = "rc.nocontext"

    # Translation table that deletes every ASCII punctuation character.
    # Built once here rather than once per alias on every scored document.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    def has_training_docs(self):
        """Report that a training split is available."""
        return True

    def has_validation_docs(self):
        """Report that a validation split is available."""
        return True

    def has_test_docs(self):
        """Report that no test split is exposed by this task."""
        return False

    def training_docs(self):
        """Return the ``"train"`` split of ``self.dataset``."""
        return self.dataset["train"]

    def validation_docs(self):
        """Return the ``"validation"`` split of ``self.dataset``."""
        return self.dataset["validation"]

    def test_docs(self):
        """Always raise: this task has no test docs (see has_test_docs).

        :raises NotImplementedError: unconditionally.
        """
        raise NotImplementedError()

    def doc_to_text(self, doc):
        """Format the prompt for ``doc`` from its ``"question"`` field."""
        return f"Question: {doc['question']}\nAnswer:"

    def should_decontaminate(self):
        """Report that this task supports decontamination."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Return the text used as decontamination query: the raw question."""
        return doc["question"]

    def doc_to_target(self, doc):
        """Return the reference continuation: a space plus the answer value."""
        return " " + doc["answer"]["value"]

    @classmethod
    def _normalize(cls, text):
        """Lower-case ``text`` and delete all ASCII punctuation from it."""
        return text.lower().translate(cls._PUNCTUATION_TABLE)

    def _remove_prefixes(self, aliases):
        """Drop every alias that starts with another alias in the list.

        This is an optimization for substring matching: if a shorter alias
        does not occur in the generation, no alias that starts with it can
        occur either, so only the shortest representative has to be checked.

        :param aliases: list of str
                Candidate answer strings. The list is NOT modified; a sorted
                copy is used so the caller's document stays untouched.
        :return: list of str
                Sorted aliases with prefixed (and duplicate) entries removed;
                an empty list when ``aliases`` is empty.
        """
        kept = []
        # After sorting, any alias that extends a kept alias follows it
        # before an unrelated alias appears, so comparing against the most
        # recently kept entry is sufficient.
        for alias in sorted(aliases):
            if not kept or not alias.startswith(kept[-1]):
                kept.append(alias)
        return kept

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
                The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
                The context string, generated by fewshot_context. This includes the natural
                language description, as well as the few shot examples, and the question
                part of the document for `doc`.
        :return:
                The request produced by ``rf.greedy_until`` for ``ctx`` with
                the stop sequences newline, period and comma.
        """
        continuation = rf.greedy_until(ctx, {"until": ["\n", ".", ","]})
        return continuation

    def process_results(self, doc, results):
        """Score one generation against the accepted answer aliases.

        Both the generation and every alias in ``doc["answer"]["aliases"]``
        are lower-cased and stripped of ASCII punctuation. The score is 1.0
        if any normalised alias is a substring of the normalised generation,
        otherwise 0.0. Note that despite the metric key ``"em"`` this is a
        containment check, not a strict exact match.

        :param doc:
                The document the request was built from.
        :param results:
                Sequence whose first element is the generated string.
        :return: dict
                ``{"em": 1.0}`` on a match, ``{"em": 0.0}`` otherwise.
        """
        generated = self._normalize(results[0].strip())

        # Normalise BEFORE prefix removal and discard aliases that become
        # empty (e.g. pure punctuation). An empty string is a substring of
        # everything, so it would mark any generation as correct, and as a
        # "prefix" of every alias it would also suppress the real answers.
        # NOTE(review): this only changes scores for docs that contain such
        # an alias; consider bumping VERSION if that matters for comparisons.
        normalized = (self._normalize(alias) for alias in doc["answer"]["aliases"])
        candidates = self._remove_prefixes([alias for alias in normalized if alias])

        is_match = any(candidate in generated for candidate in candidates)
        return {"em": float(is_match)}

    def aggregation(self):
        """Return the aggregation function per metric: ``"em"`` is averaged."""
        return {
            "em": mean,
        }

    def higher_is_better(self):
        """Return, per metric, whether larger values are better."""
        return {"em": True}