coqa.py 5.75 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
"""
CoQA: A Conversational Question Answering Challenge
https://arxiv.org/pdf/1808.07042.pdf

CoQA is a large-scale dataset for building Conversational Question Answering
systems. The goal of the CoQA challenge is to measure the ability of machines to
understand a text passage and answer a series of interconnected questions that
appear in a conversation.

Homepage: https://stanfordnlp.github.io/coqa/
"""
Jonathan Tow's avatar
Jonathan Tow committed
12
import inspect
13
import transformers.data.metrics.squad_metrics as squad_metrics
Jonathan Tow's avatar
Jonathan Tow committed
14
import lm_eval.datasets.coqa.coqa
15
from lm_eval.base import Task, rf, mean
16
from itertools import zip_longest
Jonathan Tow's avatar
Jonathan Tow committed
17

18

19
_CITATION = """
20
21
22
23
24
25
26
27
28
@misc{reddy2018coqa,
    title={CoQA: A Conversational Question Answering Challenge},
    author={Siva Reddy and Danqi Chen and Christopher D. Manning},
    year={2018},
    eprint={1808.07042},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""
29

30

31
class CoQA(Task):
    """CoQA conversational question-answering task.

    Each document is one conversation over a passage (``doc["story"]``). The
    model is prompted with the passage plus every earlier question/answer turn
    and greedily generates the answer to the final question. That answer is
    scored against the gold answer(s) with the SQuAD exact-match and token-F1
    helpers from ``squad_metrics``.
    """

    VERSION = 1
    # Path of the dataset script module bundled with lm_eval; presumably
    # consumed by Task's dataset loading (not visible here) — confirm in Task.
    DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        # No test split is exposed (has_test_docs returns False); yields None.
        pass

    def doc_to_text(self, doc):
        """Build the prompt: the passage followed by the conversation so far.

        Given a passage p, the conversation history {q1, a1, ..., q(i-1), a(i-1)}
        and a question qi, the task is to predict the answer ai. Every question
        is rendered, but the final answer is omitted and replaced by an open
        ``"A:"`` slot for the model to complete.
        """
        questions = doc["questions"]["input_text"]
        # Omit the target answer ai; zip_longest pairs the last question with None.
        history_answers = doc["answers"]["input_text"][:-1]
        parts = [doc["story"] + '\n\n']
        for q, a in zip_longest(questions, history_answers):
            parts.append(f"Q: {q}\n\n")
            parts.append(f"A: {a}\n\n" if a is not None else "A:")
        return "".join(parts)

    @classmethod
    def get_answers(cls, doc, turn_id):
        """Return the distinct gold answers for the 1-indexed turn ``turn_id``.

        Some CoQA questions have several valid answers. The primary answer from
        ``doc["answers"]`` comes first. Alternatives from
        ``doc["additional_answers"]`` (when present and non-empty) are appended
        in key order, unless they duplicate an already-collected answer when
        compared case-insensitively.
        """
        idx = turn_id - 1
        primary = doc["answers"]["input_text"][idx]
        answers = [primary]
        seen_lower = {primary.lower()}

        additional_answers = doc.get("additional_answers")
        if additional_answers:
            for key in additional_answers:
                candidate = additional_answers[key]["input_text"][idx]
                if candidate.lower() not in seen_lower:
                    answers.append(candidate)
                    seen_lower.add(candidate.lower())
        return answers

    @classmethod
    def get_answer_choice(cls, raw_text):
        """Map an answer string to a coarse CoQA answer category.

        ~1/5 of CoQA answers are Yes/No and ~2/3 are span-based (they overlap
        with the passage, ignoring punctuation and case).

        :return: '0' for "unknown", '1' for yes, '2' for no, '3' otherwise.
        """
        # Normalize once so all three special answers are matched the same way
        # (previously "unknown" was compared un-normalized, so e.g. "Unknown"
        # fell through to '3').
        normalized = squad_metrics.normalize_answer(raw_text)
        if normalized == "unknown":
            return '0'
        if normalized == "yes":
            return '1'
        if normalized == "no":
            return '2'
        return '3'  # Not a yes/no question

    @staticmethod
    def compute_scores(gold_list, pred):
        """Score ``pred`` against ``gold_list`` with exact match and token F1.

        With several golds, the score follows a leave-one-out scheme. For each
        gold, the prediction is compared against the *other* golds, taking the
        best match. These per-gold maxima are then averaged. With a single gold,
        that gold alone is used.

        :return: ``{'em': float, 'f1': float}``. Both are 0.0 when ``gold_list``
            is empty; previously ``max()`` raised ``ValueError`` in that case.
        """
        if not gold_list:
            return {'em': 0.0, 'f1': 0.0}

        if len(gold_list) > 1:
            reference_sets = [
                gold_list[:i] + gold_list[i + 1:] for i in range(len(gold_list))
            ]
        else:
            reference_sets = [gold_list]

        # Exact match on the normalized answer (compute_exact) and token overlap
        # (compute_f1); for each reference set keep the best-matching gold.
        em_sum = sum(
            max(squad_metrics.compute_exact(a, pred) for a in refs)
            for refs in reference_sets
        )
        f1_sum = sum(
            max(squad_metrics.compute_f1(a, pred) for a in refs)
            for refs in reference_sets
        )
        n = len(gold_list)
        return {'em': em_sum / n, 'f1': f1_sum / n}

    def doc_to_target(self, doc, turnid=None):
        """Return the gold answer for ``turnid`` (1-indexed), prefixed by a space.

        Defaults to the last turn, i.e. the question left open by doc_to_text.
        """
        if turnid is None:
            turnid = len(doc["questions"]["input_text"])
        raw_text = doc['answers']["input_text"][turnid - 1]
        return " " + raw_text

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # Generate greedily until the model starts a new "Q:" turn.
        cont_request = rf.greedy_until(ctx, ['\nQ:'])
        return cont_request

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        # The prompt leaves the last question open, so score against its golds.
        turn_id = len(doc["questions"]["input_text"])
        gold_list = self.get_answers(doc, turn_id)
        # Only the first line of the stripped generation is treated as the answer.
        pred = results[0].strip().split('\n')[0]

        scores = self.compute_scores(gold_list, pred)

        return {
            "f1": scores['f1'],
            "em": scores['em'],
        }

    def higher_is_better(self):
        return {
            "f1": True,
            "em": True,
        }

    def aggregation(self):
        return {
            "f1": mean,
            "em": mean,
        }