"""
CoQA: A Conversational Question Answering Challenge
https://arxiv.org/pdf/1808.07042.pdf

CoQA is a large-scale dataset for building Conversational Question Answering 
systems. The goal of the CoQA challenge is to measure the ability of machines to 
understand a text passage and answer a series of interconnected questions that 
appear in a conversation.

Homepage: https://stanfordnlp.github.io/coqa/
"""
import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
import lm_eval.datasets.coqa.coqa
from lm_eval.base import PromptSourceTask, Task, rf, mean
from itertools import zip_longest


_CITATION = """
@misc{reddy2018coqa,
    title={CoQA: A Conversational Question Answering Challenge},
    author={Siva Reddy and Danqi Chen and Christopher D. Manning},
    year={2018},
    eprint={1808.07042},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


class CoQA(PromptSourceTask):
    VERSION = 1
    DATASET_PATH = "coqa"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        pass

    # @classmethod
    # def get_answers(cls, doc, turn_id):
    #     # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
    #     answers = []
    #     answer_forturn = doc["answers"]["input_text"][turn_id - 1]
    #     answers.append(answer_forturn)
    #     additional_answers = doc.get("additional_answers")
    #     if additional_answers:
    #         for key in additional_answers:
    #             additional_answer_for_turn = additional_answers[key]["input_text"][
    #                 turn_id - 1
    #             ]
    #             if additional_answer_for_turn.lower() not in map(str.lower, answers):
    #                 answers.append(additional_answer_for_turn)
    #     return answers

    @staticmethod
    def compute_scores(gold_list, pred):
        # compute_exact tests for an exact match on the normalised answer;
        # compute_f1 tests for token overlap between the prediction and the gold.
        f1_sum = 0.0
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                gold_answers = gold_list[0:i] + gold_list[i + 1 :]
                # compare the prediction against the remaining golds
                # (leave-one-out) and take the maximum score
                em_sum += max(
                    squad_metrics.compute_exact(a, pred) for a in gold_answers
                )
                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
        else:
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

        return {
            "em": em_sum / max(1, len(gold_list)),
            "f1": f1_sum / max(1, len(gold_list)),
        }
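
    # Illustrative usage (a sketch only; exact scores depend on the answer
    # normalisation inside `squad_metrics`):
    #
    #     CoQA.compute_scores(["blue", "the blue one"], "blue")
    #
    # With more than one gold answer, each gold is held out in turn and the
    # prediction is scored against the remaining ones, so a prediction that
    # matches only one of several references still receives partial credit.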

    def eos_token(self):
        return "\n"

    # def construct_requests(self, doc, ctx):
    #     """Uses RequestFactory to construct Requests and returns an iterable of
    #     Requests which will be sent to the LM.

    #     :param doc:
    #         The document as returned from training_docs, validation_docs, or test_docs.
    #     :param ctx: str
    #         The context string, generated by fewshot_context. This includes the natural
    #         language description, as well as the few shot examples, and the question
    #         part of the document for `doc`.
    #     """
    #     return cont_request
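    #
    # A minimal sketch of what such an override could look like, assuming the
    # base `PromptSourceTask` already builds an equivalent greedy-generation
    # request (which is presumably why the stub above is commented out):
    #
    #     def construct_requests(self, doc, ctx):
    #         return rf.greedy_until(ctx, [self.eos_token()])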

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        target = self.doc_to_target(doc).strip()
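        # Keep only the first line of the completion: `eos_token` is "\n", so any
        # text after the first newline in the generation is discarded.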
        pred = results[0].strip().split("\n")[0]
        print("*" * 80)
        print(f"DOC: {doc}")
#        print(f"PS: {self.prompt.apply(doc)}")
        print(f"TEXT: {self.doc_to_text(doc)}")
        print(f"TARGET: {target} END TARGET")
        print(pred)
        print("*" * 80)

        # turn_id = len(doc["questions"]["input_text"])
        # gold_list = self.get_answers(doc, turn_id)

        # TODO: Add HF metrics mapped from promptsource metadata.
        scores = self.compute_scores([target], pred)

        return {
            "f1": scores["f1"],
            "em": scores["em"],
        }

    def higher_is_better(self):
        return {
            "f1": True,
            "em": True,
        }

    def aggregation(self):
        return {
            "f1": mean,
            "em": mean,
        }