triviaqa.py 3.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
"""
TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension
https://arxiv.org/pdf/1705.03551.pdf

TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high-quality distant supervision for answering the questions.

Homepage: https://nlp.cs.washington.edu/triviaqa/
"""
Jonathan Tow's avatar
Jonathan Tow committed
12
import inspect
seopbo's avatar
seopbo committed
13
import string
&'s avatar
& committed
14
from lm_eval.base import Task, rf
Jonathan Tow's avatar
Jonathan Tow committed
15
from lm_eval.metrics import mean
Anish Thite's avatar
Anish Thite committed
16

17
18
19
20
21
22
23
24
25
26
27
# BibTeX entry for the TriviaQA paper (Joshi et al., ACL 2017).
# NOTE(review): not referenced anywhere else in this file; presumably collected
# by the harness's citation/documentation tooling — confirm before removing.
_CITATION = """
@InProceedings{JoshiTriviaQA2017,
    author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
    title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
    booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
    month = {July},
    year = {2017},
    address = {Vancouver, Canada},
    publisher = {Association for Computational Linguistics},
}
"""
Anish Thite's avatar
Anish Thite committed
28

29

30
class TriviaQA(Task):
    """TriviaQA questions posed without their evidence documents.

    Uses the ``rc.nocontext`` configuration of the ``trivia_qa`` dataset. The
    prompt contains only the question text (see ``doc_to_text``); the model's
    generated continuation is scored by exact match (``em``) against the gold
    answer's aliases after lower-casing and deleting ASCII punctuation.
    """

    VERSION = 3
    DATASET_PATH = "trivia_qa"
    DATASET_NAME = "rc.nocontext"

    # Translation table that deletes every character in ``string.punctuation``.
    # Built once when the class is created instead of once per alias on every
    # scored document.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        # No test split is exposed; ``test_docs`` raises accordingly.
        return False

    def training_docs(self):
        # NOTE(review): ``self.dataset`` is presumably loaded by ``Task`` from
        # DATASET_PATH / DATASET_NAME — confirm in lm_eval.base.
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        raise NotImplementedError()

    def doc_to_text(self, doc):
        """Format *doc* as a ``Question: <question>\\nAnswer:`` prompt."""
        return f"Question: {doc['question']}\nAnswer:"

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        # The raw question text is used as the decontamination query.
        return doc["question"]

    def doc_to_target(self, doc):
        """Return the canonical answer string for *doc*.

        The leading space separates the answer from the ``Answer:`` suffix
        that ``doc_to_text`` ends the prompt with.
        """
        return " " + doc["answer"]["value"]

    def construct_requests(self, doc, ctx):
        """Build the generation request for *doc*.

        :param doc:
                The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
                The context string, generated by fewshot_context. This includes the natural
                language description, as well as the few shot examples, and the question
                part of the document for `doc`.
        :return:
                A single ``rf.greedy_until`` request over ``ctx`` whose ``until``
                stop sequences are newline, period and comma, so only a short
                answer span is produced for scoring.
        """
        return rf.greedy_until(ctx, {"until": ["\n", ".", ","]})

    def _normalize_answer(self, text):
        """Return *text* lower-cased with every ASCII punctuation character deleted."""
        return text.lower().translate(self._PUNCTUATION_TABLE)

    def process_results(self, doc, results):
        """Score one generation by exact match against the gold answer aliases.

        :param doc:
                The scored document; ``doc["answer"]["aliases"]`` holds the
                accepted answer strings.
        :param results:
                Results of the request built by ``construct_requests``;
                ``results[0]`` is the generated continuation.
        :return:
                ``{"em": 1.0}`` if the normalized continuation equals any
                normalized alias, otherwise ``{"em": 0.0}``.
        """
        # NOTE(review): the continuation is stripped *before* punctuation is
        # deleted, and aliases are never stripped, so e.g. "Paris !" becomes
        # "paris " and misses the alias "Paris". Kept as-is on purpose:
        # changing normalization changes scores and would need a VERSION bump.
        continuation = self._normalize_answer(results[0].strip())
        candidates = {
            self._normalize_answer(alias) for alias in doc["answer"]["aliases"]
        }
        return {"em": float(continuation in candidates)}

    def aggregation(self):
        # Per-document exact-match scores are averaged into the task score.
        return {
            "em": mean,
        }

    def higher_is_better(self):
        return {"em": True}