drop.py 9.25 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
"""
DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs
https://aclanthology.org/attachments/N19-1246.Supplementary.pdf

DROP is a QA dataset which tests comprehensive understanding of paragraphs. In 
this crowdsourced, adversarially-created, 96k question-answering benchmark, a 
system must resolve multiple references in a question, map them onto a paragraph,
and perform discrete operations over them (such as addition, counting, or sorting).

Homepage: https://allenai.org/data/drop

Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
"""
Jonathan Tow's avatar
Jonathan Tow committed
15
import inspect
Jon Tow's avatar
Jon Tow committed
16
17
import numpy as np
import re
18
import string
Jonathan Tow's avatar
Jonathan Tow committed
19
import lm_eval.datasets.drop.drop
Jon Tow's avatar
Jon Tow committed
20
from scipy.optimize import linear_sum_assignment
21
from lm_eval.base import PromptSourceTask, rf
Jon Tow's avatar
Jon Tow committed
22
23
from lm_eval.metrics import mean

24

Jonathan Tow's avatar
Jonathan Tow committed
25
_CITATION = """
26
27
28
29
30
31
32
33
34
35
36
@misc{dua2019drop,
    title={DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs}, 
    author={Dheeru Dua and Yizhong Wang and Pradeep Dasigi and Gabriel Stanovsky and Sameer Singh and Matt Gardner},
    year={2019},
    eprint={1903.00161},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


silentv0x's avatar
silentv0x committed
37
# Matches the standalone lowercase words "a", "an" and "the". The pattern is
# case-sensitive; `_normalize` lowercases each token before it is applied.
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
Anish Thite's avatar
Anish Thite committed
38

39

40
class DROP(PromptSourceTask):
Leo Gao's avatar
Leo Gao committed
41
    VERSION = 1
cjlovering's avatar
cjlovering committed
42
    DATASET_PATH = "drop"  # inspect.getfile(lm_eval.datasets.drop.drop)
Jonathan Tow's avatar
Jonathan Tow committed
43
    DATASET_NAME = None
44

Anish Thite's avatar
Anish Thite committed
45
46
    def has_training_docs(self):
        """Report that this task exposes a training split (always ``True``)."""
        return True
Jon Tow's avatar
Jon Tow committed
47

Anish Thite's avatar
Anish Thite committed
48
49
50
51
52
53
    def has_validation_docs(self):
        """Report that this task exposes a validation split (always ``True``)."""
        return True

    def has_test_docs(self):
        """Report that this task exposes no test split (always ``False``)."""
        return False

Jonathan Tow's avatar
Jonathan Tow committed
54
    def training_docs(self):
jon-tow's avatar
jon-tow committed
55
56
57
58
        # if self._training_docs is None:
        #     self._training_docs = list()
        # return self._training_docs
        return self.dataset["train"]
Jonathan Tow's avatar
Jonathan Tow committed
59
60

    def validation_docs(self):
jon-tow's avatar
jon-tow committed
61
        return self.dataset["validation"]
silentv0x's avatar
silentv0x committed
62
63
64
65
66
67
68
69

    @classmethod
    def parse_answer(cls, answer):
        # NOTE: Everything is returned as a tuple for uniformity and hashability.
        if answer["number"] != "":
            return (str(answer["number"]),)
        if answer["spans"] != []:
            return tuple(answer["spans"])
cjlovering's avatar
cjlovering committed
70
71
72
73
74
        return (
            " ".join(
                [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
            ).strip(),
        )
Jon Tow's avatar
Jon Tow committed
75

cjlovering's avatar
cjlovering committed
76
77
    # def doc_to_text(self, doc):
    #     return f"Passage: {doc['passage']}\nQuestion: {doc['question']}\nAnswer:"
Jon Tow's avatar
Jon Tow committed
78

cjlovering's avatar
cjlovering committed
79
80
    # def doc_to_target(self, doc):
    #     return " " + ", ".join(doc["answers"][0])
Anish Thite's avatar
Anish Thite committed
81

jon-tow's avatar
jon-tow committed
82
83
84
85
86
87
88
89
90
91
92
93
94
    # def construct_requests(self, doc, ctx):
    #     """Uses RequestFactory to construct Requests and returns an iterable of
    #     Requests which will be sent to the LM.

    #     :param doc:
    #         The document as returned from training_docs, validation_docs, or test_docs.
    #     :param ctx: str
    #         The context string, generated by fewshot_context. This includes the natural
    #         language description, as well as the few shot examples, and the question
    #         part of the document for `doc`.
    #     """
    #     conts = [rf.greedy_until(ctx, ["."])]
    #     return conts
cjlovering's avatar
cjlovering committed
95
    def stopping_criteria(self):
        """Return the stop string ``"."``.

        NOTE(review): presumably consumed by the base task to end generation at
        the first period, matching the legacy ``greedy_until(ctx, ["."])``
        request kept in the comments of this file -- confirm against the base
        class.
        """
        return "."
Jon Tow's avatar
Jon Tow committed
97

Leo Gao's avatar
Leo Gao committed
98
    def process_results(self, doc, results):
Jon Tow's avatar
Jon Tow committed
99
100
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
Leo Gao's avatar
Leo Gao committed
101
102
        the metric for that one document

Jon Tow's avatar
Jon Tow committed
103
        :param doc:
Jon Tow's avatar
Jon Tow committed
104
            The document as returned from training_docs, validation_docs, or test_docs.
Leo Gao's avatar
Leo Gao committed
105
106
107
        :param results:
            The results of the requests created in construct_requests.
        """
cjlovering's avatar
cjlovering committed
108
109
110
111

        pred = results[0].strip()
        target = self.doc_to_target(doc).strip()

jon-tow's avatar
jon-tow committed
112
113
114
115
116
        print("*" * 80)
        print(f"DOC: {doc}")
        print(f"PS: {self.prompt.apply(doc)}")
        print(f"TEXT: {self.doc_to_text(doc)}")
        print(f"TARGET: {target} END TARGET")
cjlovering's avatar
cjlovering committed
117
        print(f"PRED: {pred} END PRED")
jon-tow's avatar
jon-tow committed
118
119
        print("*" * 80)

cjlovering's avatar
cjlovering committed
120
121
122
        preds = [pred]
        golds = [target]

silentv0x's avatar
silentv0x committed
123
124
125
126
127
128
129
        max_em = 0
        max_f1 = 0
        for gold_answer in golds:
            exact_match, f1_score = self.get_metrics(preds, gold_answer)
            if gold_answer[0].strip():
                max_em = max(max_em, exact_match)
                max_f1 = max(max_f1, f1_score)
cjlovering's avatar
cjlovering committed
130
        return {"em": max_em, "f1": max_f1}
Jon Tow's avatar
Jon Tow committed
131

silentv0x's avatar
silentv0x committed
132
133
134
135
136
137
138
    def get_metrics(self, predicted, gold):
        """
        Takes a predicted answer and a gold answer (that are both either a string or a list of
        strings), and returns exact match and the DROP F1 metric for the prediction.  If you are
        writing a script for evaluating objects in memory (say, the output of predictions during
        validation, or while training), this is the function you want to call, after using
        :func:`answer_json_to_strings` when reading the gold answer from the released data file.
139
        """
silentv0x's avatar
silentv0x committed
140
141
142
        predicted_bags = self._answer_to_bags(predicted)
        gold_bags = self._answer_to_bags(gold)

cjlovering's avatar
cjlovering committed
143
144
145
        if set(predicted_bags[0]) == set(gold_bags[0]) and len(
            predicted_bags[0]
        ) == len(gold_bags[0]):
silentv0x's avatar
silentv0x committed
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
            exact_match = 1.0
        else:
            exact_match = 0.0

        f1_per_bag = self._align_bags(predicted_bags[1], gold_bags[1])
        f1 = np.mean(f1_per_bag)
        f1 = round(f1, 2)
        return exact_match, f1

    def _answer_to_bags(self, answer):
        if isinstance(answer, (list, tuple)):
            raw_spans = answer
        else:
            raw_spans = [answer]
        normalized_spans = []
        token_bags = []
        for raw_span in raw_spans:
            normalized_span = self._normalize(raw_span)
            normalized_spans.append(normalized_span)
            token_bags.append(set(normalized_span.split()))
        return normalized_spans, token_bags

    def _align_bags(self, predicted, gold):
        """
        Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
        between them and gets maximum metric values over all the answers.
        """
        scores = np.zeros([len(gold), len(predicted)])
        for gold_index, gold_item in enumerate(gold):
            for pred_index, pred_item in enumerate(predicted):
                if self._match_numbers_if_present(gold_item, pred_item):
cjlovering's avatar
cjlovering committed
177
178
179
                    scores[gold_index, pred_index] = self._compute_f1(
                        pred_item, gold_item
                    )
Jon Tow's avatar
Jon Tow committed
180
        row_ind, col_ind = linear_sum_assignment(-scores)
silentv0x's avatar
silentv0x committed
181
182

        max_scores = np.zeros([max(len(gold), len(predicted))])
Jon Tow's avatar
Jon Tow committed
183
184
185
186
        for row, column in zip(row_ind, col_ind):
            max_scores[row] = max(max_scores[row], scores[row, column])
        return max_scores

silentv0x's avatar
silentv0x committed
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
    def _compute_f1(self, predicted_bag, gold_bag):
        intersection = len(gold_bag.intersection(predicted_bag))
        if not predicted_bag:
            precision = 1.0
        else:
            precision = intersection / float(len(predicted_bag))
        if not gold_bag:
            recall = 1.0
        else:
            recall = intersection / float(len(gold_bag))
        f1 = (
            (2 * precision * recall) / (precision + recall)
            if not (precision == 0.0 and recall == 0.0)
            else 0.0
        )
Jon Tow's avatar
Jon Tow committed
202
203
        return f1

silentv0x's avatar
silentv0x committed
204
205
206
207
208
209
210
211
212
213
    def _match_numbers_if_present(self, gold_bag, predicted_bag):
        gold_numbers = set()
        predicted_numbers = set()
        for word in gold_bag:
            if self._is_number(word):
                gold_numbers.add(word)
        for word in predicted_bag:
            if self._is_number(word):
                predicted_numbers.add(word)
        if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
214
215
216
217
218
219
220
221
222
            return True
        return False

    def _is_number(self, text):
        try:
            float(text)
            return True
        except ValueError:
            return False
Jon Tow's avatar
Jon Tow committed
223

silentv0x's avatar
silentv0x committed
224
225
    def _remove_articles(self, text):
        """Replace every match of the module-level ``_ARTICLES`` pattern in
        ``text`` with a single space and return the result."""
        without_articles = _ARTICLES.sub(" ", text)
        return without_articles
226

silentv0x's avatar
silentv0x committed
227
228
    def _white_space_fix(self, text):
        return " ".join(text.split())
229

silentv0x's avatar
silentv0x committed
230
231
232
233
234
235
    def _remove_punc(self, text):
        exclude = set(string.punctuation)
        if not self._is_number(text):
            return "".join(ch for ch in text if ch not in exclude)
        else:
            return text
236

silentv0x's avatar
silentv0x committed
237
238
    def _fix_number(self, text):
        return str(float(text)) if self._is_number(text) else text
239

silentv0x's avatar
Bug fix  
silentv0x committed
240
    def _tokenize(self, text):
silentv0x's avatar
silentv0x committed
241
        return re.split(" |-", text)
242

silentv0x's avatar
silentv0x committed
243
    def _normalize(self, answer):
244
        tokens = [
cjlovering's avatar
cjlovering committed
245
246
247
248
249
            self._white_space_fix(
                self._remove_articles(
                    self._fix_number(self._remove_punc(token.lower()))
                )
            )
silentv0x's avatar
silentv0x committed
250
            for token in self._tokenize(answer)
251
        ]
Jon Tow's avatar
Fixes  
Jon Tow committed
252
        tokens = [token for token in tokens if token.strip()]
Jon Tow's avatar
Jon Tow committed
253
254
        normalized = " ".join(tokens).strip()
        return normalized
Leo Gao's avatar
Leo Gao committed
255
256
257
258

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics. Both ``"em"`` and
            ``"f1"`` are aggregated with ``mean``.
        """
        return {"em": mean, "f1": mean}
Leo Gao's avatar
Leo Gao committed
263
264
265
266

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better. Both ``"em"``
            and ``"f1"`` map to ``True``.
        """
        return {"em": True, "f1": True}