"""
CoQA: A Conversational Question Answering Challenge
https://arxiv.org/pdf/1808.07042.pdf

CoQA is a large-scale dataset for building Conversational Question Answering 
systems. The goal of the CoQA challenge is to measure the ability of machines to 
understand a text passage and answer a series of interconnected questions that 
appear in a conversation.

Homepage: https://stanfordnlp.github.io/coqa/
"""
import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
import lm_eval.datasets.coqa.coqa
from lm_eval.base import PromptSourceTask, Task, rf, mean
from itertools import zip_longest


_CITATION = """
@misc{reddy2018coqa,
    title={CoQA: A Conversational Question Answering Challenge},
    author={Siva Reddy and Danqi Chen and Christopher D. Manning},
    year={2018},
    eprint={1808.07042},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


class CoQA(PromptSourceTask):
    VERSION = 1
    DATASET_PATH = "coqa"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        pass

    # @classmethod
    # def get_answers(cls, doc, turn_id):
    #     # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
    #     answers = []
    #     answer_forturn = doc["answers"]["input_text"][turn_id - 1]
    #     answers.append(answer_forturn)
    #     additional_answers = doc.get("additional_answers")
    #     if additional_answers:
    #         for key in additional_answers:
    #             additional_answer_for_turn = additional_answers[key]["input_text"][
    #                 turn_id - 1
    #             ]
    #             if additional_answer_for_turn.lower() not in map(str.lower, answers):
    #                 answers.append(additional_answer_for_turn)
    #     return answers

    @staticmethod
    def compute_scores(gold_list, pred):
        # Exact match (EM) compares the normalized prediction and gold answer
        # for string equality (compute_exact); F1 measures their token overlap
        # (compute_f1).
        f1_sum = 0.0
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                gold_answers = gold_list[0:i] + gold_list[i + 1 :]
                # Leave one gold out and compare the prediction against the
                # remaining n - 1 golds, taking the maximum score.
                em_sum += max(
                    squad_metrics.compute_exact(a, pred) for a in gold_answers
                )
                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
        else:
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

        return {
            "em": em_sum / max(1, len(gold_list)),
            "f1": f1_sum / max(1, len(gold_list)),
        }
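
    # Worked example of the leave-one-out averaging above (hypothetical
    # strings, not drawn from the dataset): with gold_list = ["white",
    # "snow white"] and pred = "white", pass i=0 scores the prediction
    # against ["snow white"] (em = 0, f1 = 2/3) and pass i=1 against
    # ["white"] (em = 1, f1 = 1), giving em = 0.5 and f1 = 5/6 after
    # dividing by len(gold_list) = 2.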

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, the few-shot examples, and the question
            part of the document for `doc`.
        """
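        # Generate greedily until "\nQ:", which marks the start of the next
        # question turn in the prompt, so the completion stops before the
        # model starts a new question/answer round.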
        cont_request = rf.greedy_until(ctx, ["\nQ:"])
        return cont_request

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
Leo Gao's avatar
Leo Gao committed
110
111
112
113
114
115
116
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
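        # The target is the gold answer string produced by doc_to_target; only
        # the first line of the completion is kept as the prediction, since
        # greedy generation may run past the answer.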
        target = self.doc_to_target(doc).strip()
        pred = results[0].strip().split("\n")[0]

        # turn_id = len(doc["questions"]["input_text"])
        # gold_list = self.get_answers(doc, turn_id)

        # TODO: Add HF metrics mapped from promptsource metadata.
        scores = self.compute_scores([target], pred)

        return {
            "f1": scores["f1"],
            "em": scores["em"],
        }

    def higher_is_better(self):
        return {
            "f1": True,
            "em": True,
        }

    def aggregation(self):
        return {
            "f1": mean,
            "em": mean,
        }
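

# Illustrative sketch of how the harness drives this task for each document
# (the real orchestration lives in lm_eval's evaluator; `lm` and the exact
# `fewshot_context` signature below stand in for the assumed harness API):
#
#   task = CoQA()
#   for doc in task.validation_docs():
#       ctx = task.fewshot_context(doc, num_fewshot=0)  # assumed signature
#       request = task.construct_requests(doc, ctx)     # greedy_until request
#       results = lm(request)                           # model completion(s)
#       metrics = task.process_results(doc, results)    # {"f1": ..., "em": ...}
#   # Per-document metrics are then reduced with task.aggregation() (mean).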