"""
CoQA: A Conversational Question Answering Challenge
https://arxiv.org/pdf/1808.07042.pdf

CoQA is a large-scale dataset for building Conversational Question Answering
systems. The goal of the CoQA challenge is to measure the ability of machines to
understand a text passage and answer a series of interconnected questions that
appear in a conversation.

Homepage: https://stanfordnlp.github.io/coqa/
"""
import inspect

import transformers.data.metrics.squad_metrics as squad_metrics

import lm_eval.datasets.coqa.coqa
from lm_eval.base import PromptSourceTask, Task, rf, mean
from itertools import zip_longest

# BibTeX citation for the CoQA paper, surfaced by the eval harness.
_CITATION = """
@misc{reddy2018coqa,
    title={CoQA: A Conversational Question Answering Challenge},
    author={Siva Reddy and Danqi Chen and Christopher D. Manning},
    year={2018},
    eprint={1808.07042},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


class CoQA(PromptSourceTask):
    """Conversational Question Answering (CoQA) task.

    Free-form generation task: the model answers each question in a
    conversation about a passage. Scoring follows the official CoQA
    evaluation — exact match (EM) and token-level F1 — using the SQuAD
    metric implementations from `transformers`.
    """

    VERSION = 1
    DATASET_PATH = "coqa"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        # The CoQA test split is not publicly labeled, so it cannot be scored.
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        # No scorable test split (see `has_test_docs`).
        pass

    @staticmethod
    def compute_scores(gold_list, pred):
        """Score `pred` against the gold answers.

        :param gold_list: list of gold answer strings (one or more).
        :param pred: the model's predicted answer string.
        :return: dict with "em" (exact match on normalized answers) and
            "f1" (token-overlap F1), each averaged over the gold answers.

        With multiple golds this uses the official CoQA leave-one-out
        scheme: each gold is held out in turn, the prediction is scored
        by its maximum against the remaining golds, and the per-gold
        scores are averaged.
        """
        f1_sum = 0.0
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                # Hold out gold i; compare against the remaining n-1 golds.
                gold_answers = gold_list[0:i] + gold_list[i + 1 :]
                # The prediction is compared against (n-1) golds; take the max.
                em_sum += max(
                    squad_metrics.compute_exact(a, pred) for a in gold_answers
                )
                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
        else:
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

        # max(1, ...) guards against division by zero on an empty gold_list.
        return {
            "em": em_sum / max(1, len(gold_list)),
            "f1": f1_sum / max(1, len(gold_list)),
        }

    def stopping_criteria(self):
        # Generation stops at a blank line, i.e. the end of the answer.
        return "\n\n"

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        target = self.doc_to_target(doc).strip()
        # Keep only the first generated line as the predicted answer.
        pred = results[0].strip().split("\n")[0]
        scores = self.compute_scores([target], pred)

        out = {
            "f1": scores["f1"],
            "em": scores["em"],
        }

        if self.save_examples:
            example = {"target": target, "pred": pred}
            return out, example
        return out

    def higher_is_better(self):
        return {
            "f1": True,
            "em": True,
        }

    def aggregation(self):
        return {
            "f1": mean,
            "em": mean,
        }