"""
CoQA: A Conversational Question Answering Challenge
https://arxiv.org/pdf/1808.07042.pdf

CoQA is a large-scale dataset for building Conversational Question Answering 
systems. The goal of the CoQA challenge is to measure the ability of machines to 
understand a text passage and answer a series of interconnected questions that 
appear in a conversation.

Homepage: https://stanfordnlp.github.io/coqa/
"""
import inspect
import transformers.data.metrics.squad_metrics as squad_metrics
import lm_eval.datasets.coqa.coqa
from lm_eval.base import PromptSourceTask, rf, mean


_CITATION = """
@misc{reddy2018coqa,
    title={CoQA: A Conversational Question Answering Challenge},
    author={Siva Reddy and Danqi Chen and Christopher D. Manning},
    year={2018},
    eprint={1808.07042},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


class CoQA(PromptSourceTask):
    VERSION = 1
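    # DATASET_PATH points at the local HuggingFace `datasets` loading script
    # bundled with lm_eval (lm_eval/datasets/coqa/coqa.py); DATASET_NAME is None
    # because no named sub-configuration is selected.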
    DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        pass

    # @classmethod
    # def get_answers(cls, doc, turn_id):
    #     # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
    #     answers = []
    #     answer_forturn = doc["answers"]["input_text"][turn_id - 1]
    #     answers.append(answer_forturn)

    #     additional_answers = doc.get("additional_answers")
    #     if additional_answers:
    #         for key in additional_answers:
    #             additional_answer_for_turn = additional_answers[key]["input_text"][
    #                 turn_id - 1
    #             ]
    #             if additional_answer_for_turn.lower() not in map(str.lower, answers):
    #                 answers.append(additional_answer_for_turn)
    #     return answers

    @staticmethod
    def compute_scores(gold_list, pred):
        # Exact match (EM) is scored on the normalized answer (compute_exact);
        # token-level overlap is scored with F1 (compute_f1).
        f1_sum = 0.0
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                gold_answers = gold_list[0:i] + gold_list[i + 1 :]
                # Leave one gold out and compare the prediction against the
                # remaining (n - 1) golds, taking the maximum score.
                em_sum += max(
                    squad_metrics.compute_exact(a, pred) for a in gold_answers
                )
                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
        else:
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

        return {
            "em": em_sum / max(1, len(gold_list)),
            "f1": f1_sum / max(1, len(gold_list)),
        }

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        cont_request = rf.greedy_until(ctx, ["\nQ:"])
        return cont_request

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        target = self.doc_to_target(doc).strip()
        pred = results[0].strip().split("\n")[0]

        # turn_id = len(doc["questions"]["input_text"])
        # gold_list = self.get_answers(doc, turn_id)

        # TODO: Add HF metrics mapped from promptsource metadata.
        scores = self.compute_scores([target], pred)

        return {
            "f1": scores["f1"],
            "em": scores["em"],
        }

    def higher_is_better(self):
        return {
            "f1": True,
            "em": True,
        }

    def aggregation(self):
        return {
            "f1": mean,
            "em": mean,
        }
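

# A minimal, optional sanity check of the leave-one-out scoring above. It is not
# part of the harness; it assumes `transformers` is installed so that
# `squad_metrics.compute_exact`/`compute_f1` resolve, and the answers below are
# made-up illustrative values.
if __name__ == "__main__":
    golds = ["blue", "Blue", "a blue one"]
    # For each gold answer, the prediction is scored against the remaining golds
    # and the maximum is taken; the per-gold maxima are then averaged.
    print(CoQA.compute_scores(golds, "blue"))  # expected: {"em": 1.0, "f1": 1.0}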