"""
CoQA: A Conversational Question Answering Challenge
https://arxiv.org/pdf/1808.07042.pdf

CoQA is a large-scale dataset for building Conversational Question Answering
systems. The goal of the CoQA challenge is to measure the ability of machines to
understand a text passage and answer a series of interconnected questions that
appear in a conversation.

Homepage: https://stanfordnlp.github.io/coqa/
"""
import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
import lm_eval.datasets.coqa.coqa
from lm_eval.base import Task, rf, mean
from itertools import zip_longest


_CITATION = """
@misc{reddy2018coqa,
    title={CoQA: A Conversational Question Answering Challenge},
    author={Siva Reddy and Danqi Chen and Christopher D. Manning},
    year={2018},
    eprint={1808.07042},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


class CoQA(Task):
    VERSION = 1
    DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        pass

    def doc_to_text(self, doc):
        # Given a passage p, the conversation history {q1, a1, ..., q_{i-1}, a_{i-1}},
        # and a question q_i, the task is to predict the answer a_i.
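        #
        # For illustration (toy values, not from the dataset), a two-turn
        # document renders as:
        #
        #   {story}
        #
        #   Q: {q1}
        #
        #   A: {a1}
        #
        #   Q: {q2}
        #
        #   A:
        #
        # zip_longest pads the final turn with None, leaving the last "A:"
        # open for the model to complete.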
        doc_text = doc["story"] + "\n\n"
        for (q, a) in zip_longest(
            doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
        ):  # omit target answer ai
            question = f"Q: {q}\n\n"
            answer = f"A: {a}\n\n" if a is not None else "A:"
            doc_text += question + answer
        return doc_text

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["story"] + " " + "\n".join(doc["questions"]["input_text"])

    @classmethod
    def get_answers(cls, doc, turn_id):
        # Returns the gold answer for this turn plus any unique annotator
        # alternatives (some questions in CoQA have multiple valid answers).
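        # For example, if the gold answer is "Nina" and one annotator wrote
        # "Nina did", this returns ["Nina", "Nina did"]; a case-only duplicate
        # such as "nina" would be skipped.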
        answers = []
        answer_forturn = doc["answers"]["input_text"][turn_id - 1]
        answers.append(answer_forturn)

        additional_answers = doc.get("additional_answers")
        if additional_answers:
            for key in additional_answers:
                additional_answer_for_turn = additional_answers[key]["input_text"][
                    turn_id - 1
                ]
                if additional_answer_for_turn.lower() not in map(str.lower, answers):
                    answers.append(additional_answer_for_turn)
        return answers

    @classmethod
    def get_answer_choice(cls, raw_text):
        # Maps an answer string to a CoQA answer category.
        # ~ 1/5 of the CoQA answers are Yes/No
        # ~ 2/3 of the CoQA answers are span-based
        # (answers overlap with the passage ignoring punctuation and case mismatch)
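        # e.g. get_answer_choice("Yes.") -> "1", since normalize_answer strips
        # punctuation and case before the comparison; the "unknown" check below
        # matches the raw string only.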
        if raw_text == "unknown":
            return "0"
        if squad_metrics.normalize_answer(raw_text) == "yes":
            return "1"
        if squad_metrics.normalize_answer(raw_text) == "no":
            return "2"
        return "3"  # Not a yes/no question

    @staticmethod
    def compute_scores(gold_list, pred):
        # compute_exact tests for exact match on the normalized answer;
        # compute_f1 tests for token overlap.
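        # Worked example (normalize_answer lowercases and strips articles and
        # punctuation): compute_scores(["white house", "the white house"], "house")
        # holds out each gold in turn, scores against the other (which
        # normalizes to "white house"), and averages: em 0.0, f1 ~0.67.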
        f1_sum = 0.0
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                gold_answers = gold_list[0:i] + gold_list[i + 1 :]
                # hold out gold i and score the prediction against the
                # remaining golds, taking the maximum
                em_sum += max(
                    squad_metrics.compute_exact(a, pred) for a in gold_answers
                )
                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
        else:
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

        return {
            "em": em_sum / max(1, len(gold_list)),
            "f1": f1_sum / max(1, len(gold_list)),
        }

    def doc_to_target(self, doc, turnid=None):
        # Default to prediction of last turn.
        if turnid is None:
            turnid = len(doc["questions"]["input_text"])
        raw_text = doc["answers"]["input_text"][turnid - 1]
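        # The leading space matters: doc_to_text leaves the context ending in
        # "A:", so the completed example reads "A: {answer}".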
        return " " + raw_text

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        cont_request = rf.greedy_until(ctx, {"until": ["\nQ:"]})
        return cont_request

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
Leo Gao's avatar
Leo Gao committed
150
151
152
153
154
155
156
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        turn_id = len(doc["questions"]["input_text"])
        gold_list = self.get_answers(doc, turn_id)
        pred = results[0].strip().split("\n")[0]

        scores = self.compute_scores(gold_list, pred)

        return {
            "f1": scores["f1"],
            "em": scores["em"],
        }

    def higher_is_better(self):
        return {
            "f1": True,
            "em": True,
        }

    def aggregation(self):
        return {
            "f1": mean,
            "em": mean,
        }
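

# A minimal smoke-test sketch (illustrative only; the toy dict simply mirrors
# the field layout the methods above read from the HF CoQA split):
if __name__ == "__main__":
    toy_doc = {
        "story": "Nina adopted a small black cat named Miso.",
        "questions": {"input_text": ["Who adopted a cat?", "What color is it?"]},
        "answers": {"input_text": ["Nina", "black"]},
        "additional_answers": {"0": {"input_text": ["Nina did", "Black"]}},
    }
    golds = CoQA.get_answers(toy_doc, turn_id=2)
    print(golds)  # ['black'] -- "Black" is skipped as a case-only duplicate
    print(CoQA.compute_scores(golds, "black"))  # {'em': 1.0, 'f1': 1.0}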