drop.py 10.3 KB
Newer Older
1
2
3
4
"""
DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs
https://aclanthology.org/attachments/N19-1246.Supplementary.pdf

DROP is a QA dataset which tests comprehensive understanding of paragraphs. In
this crowdsourced, adversarially-created, 96k question-answering benchmark, a
system must resolve multiple references in a question, map them onto a paragraph,
and perform discrete operations over them (such as addition, counting, or sorting).

Homepage: https://allenai.org/data/drop

Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
"""
Jonathan Tow's avatar
Jonathan Tow committed
15
import inspect
Jon Tow's avatar
Jon Tow committed
16
17
import numpy as np
import re
18
import string
Jonathan Tow's avatar
Jonathan Tow committed
19
import lm_eval.datasets.drop.drop
Jon Tow's avatar
Jon Tow committed
20
21
22
23
from scipy.optimize import linear_sum_assignment
from lm_eval.base import Task, rf
from lm_eval.metrics import mean

24

Jonathan Tow's avatar
Jonathan Tow committed
25
_CITATION = """
26
@misc{dua2019drop,
bzantium's avatar
bzantium committed
27
    title={DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs},
28
29
30
31
32
33
34
35
36
    author={Dheeru Dua and Yizhong Wang and Pradeep Dasigi and Gabriel Stanovsky and Sameer Singh and Matt Gardner},
    year={2019},
    eprint={1903.00161},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""


silentv0x's avatar
silentv0x committed
37
# Matches the lowercase English articles "a", "an" and "the" as whole words.
# `DROP._remove_articles` substitutes each match with a single space; answers
# are lowercased in `DROP._normalize` before this pattern is applied.
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
Anish Thite's avatar
Anish Thite committed
38

39

40
class DROP(Task):
Leo Gao's avatar
Leo Gao committed
41
    VERSION = 1
Jonathan Tow's avatar
Jonathan Tow committed
42
43
    DATASET_PATH = inspect.getfile(lm_eval.datasets.drop.drop)
    DATASET_NAME = None
44

Anish Thite's avatar
Anish Thite committed
45
46
    def has_training_docs(self):
        """Report that this task exposes a training split (always ``True``)."""
        return True
Jon Tow's avatar
Jon Tow committed
47

Anish Thite's avatar
Anish Thite committed
48
49
50
51
52
53
    def has_validation_docs(self):
        """Report that this task exposes a validation split (always ``True``)."""
        return True

    def has_test_docs(self):
        """Report that this task exposes no test split (always ``False``)."""
        return False

Jonathan Tow's avatar
Jonathan Tow committed
54
    def training_docs(self):
Jon Tow's avatar
Jon Tow committed
55
        if self._training_docs is None:
Jon Tow's avatar
Jon Tow committed
56
            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
Jon Tow's avatar
Jon Tow committed
57
        return self._training_docs
Jonathan Tow's avatar
Jonathan Tow committed
58
59

    def validation_docs(self):
Jon Tow's avatar
Jon Tow committed
60
        return map(self._process_doc, self.dataset["validation"])
Jonathan Tow's avatar
Jonathan Tow committed
61

Jon Tow's avatar
Jon Tow committed
62
    def _process_doc(self, doc):
Jonathan Tow's avatar
Jonathan Tow committed
63
64
65
66
67
68
        return {
            "id": doc["query_id"],
            "passage": doc["passage"],
            "question": doc["question"],
            "answers": self.get_answers(doc),
        }
Anish Thite's avatar
Anish Thite committed
69

Jon Tow's avatar
Jon Tow committed
70
    @classmethod
silentv0x's avatar
silentv0x committed
71
    def get_answers(cls, qa):
Jonathan Tow's avatar
Jonathan Tow committed
72
        def _flatten_validated_answers(validated_answers):
bzantium's avatar
bzantium committed
73
            """Flattens a dict of lists of validated answers.
Jonathan Tow's avatar
Jonathan Tow committed
74
75
76
            {"number": ['1', '8'], ...}
            -> [{"number": ['1'], ...}, {"number": ['8'], ...}]
            """
bzantium's avatar
bzantium committed
77
            valid_answers = []
Jonathan Tow's avatar
Jonathan Tow committed
78
            for i in range(len(validated_answers["number"])):
bzantium's avatar
bzantium committed
79
80
81
82
83
84
85
86
87
                valid_answers.append(
                    {
                        "number": validated_answers["number"][i],
                        "date": validated_answers["date"][i],
                        "spans": validated_answers["spans"][i],
                    }
                )
            return valid_answers

silentv0x's avatar
silentv0x committed
88
89
        answers = []
        answers_set = set()
bzantium's avatar
bzantium committed
90
91
92
        candidates = [qa["answer"]] + _flatten_validated_answers(
            qa["validated_answers"]
        )
silentv0x's avatar
silentv0x committed
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
        for candidate in candidates:
            answer = cls.parse_answer(candidate)
            if answer in answers_set:
                continue
            answers_set.add(answer)
            answers.append(answer)
        return answers

    @classmethod
    def parse_answer(cls, answer):
        # NOTE: Everything is returned as a tuple for uniformity and hashability.
        if answer["number"] != "":
            return (str(answer["number"]),)
        if answer["spans"] != []:
            return tuple(answer["spans"])
bzantium's avatar
bzantium committed
108
109
110
111
112
        return (
            " ".join(
                [answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]]
            ).strip(),
        )
Jon Tow's avatar
Jon Tow committed
113
114
115
116

    def doc_to_text(self, doc):
        """Render the prompt for a document: passage, question, then an open "Answer:"."""
        passage, question = doc["passage"], doc["question"]
        return f"Passage: {passage}\nQuestion: {question}\nAnswer:"

bzantium's avatar
bzantium committed
117
118
119
120
121
122
    def should_decontaminate(self):
        """Report that this task supplies a decontamination query (always ``True``)."""
        return True

    def doc_to_decontamination_query(self, doc):
        """Return the passage and question joined by a single space."""
        passage = doc["passage"]
        question = doc["question"]
        return passage + " " + question

Jon Tow's avatar
Jon Tow committed
123
    def doc_to_target(self, doc):
        """Return the target string: a leading space plus the first answer's parts joined by ", "."""
        first_answer = doc["answers"][0]
        joined = ", ".join(first_answer)
        return " " + joined
Anish Thite's avatar
Anish Thite committed
125

Leo Gao's avatar
Leo Gao committed
126
    def construct_requests(self, doc, ctx):
        """Build the list of requests to send to the LM for one document.

        A single ``rf.greedy_until`` request is created for ``ctx`` with "."
        as its only ``until`` entry.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        stop_condition = {"until": ["."]}
        return [rf.greedy_until(ctx, stop_condition)]

Leo Gao's avatar
Leo Gao committed
140
    def process_results(self, doc, results):
Jon Tow's avatar
Jon Tow committed
141
142
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
Leo Gao's avatar
Leo Gao committed
143
144
        the metric for that one document

Jon Tow's avatar
Jon Tow committed
145
        :param doc:
Jon Tow's avatar
Jon Tow committed
146
            The document as returned from training_docs, validation_docs, or test_docs.
Leo Gao's avatar
Leo Gao committed
147
148
149
        :param results:
            The results of the requests created in construct_requests.
        """
150
        preds, golds = results, doc["answers"]
silentv0x's avatar
silentv0x committed
151
152
153
154
155
156
157
        max_em = 0
        max_f1 = 0
        for gold_answer in golds:
            exact_match, f1_score = self.get_metrics(preds, gold_answer)
            if gold_answer[0].strip():
                max_em = max(max_em, exact_match)
                max_f1 = max(max_f1, f1_score)
bzantium's avatar
bzantium committed
158
        return {"em": max_em, "f1": max_f1}
Jon Tow's avatar
Jon Tow committed
159

silentv0x's avatar
silentv0x committed
160
161
162
163
164
165
166
    def get_metrics(self, predicted, gold):
        """
        Takes a predicted answer and a gold answer (that are both either a string or a list of
        strings), and returns exact match and the DROP F1 metric for the prediction.  If you are
        writing a script for evaluating objects in memory (say, the output of predictions during
        validation, or while training), this is the function you want to call, after using
        :func:`answer_json_to_strings` when reading the gold answer from the released data file.
167
        """
silentv0x's avatar
silentv0x committed
168
169
170
        predicted_bags = self._answer_to_bags(predicted)
        gold_bags = self._answer_to_bags(gold)

bzantium's avatar
bzantium committed
171
172
173
        if set(predicted_bags[0]) == set(gold_bags[0]) and len(
            predicted_bags[0]
        ) == len(gold_bags[0]):
silentv0x's avatar
silentv0x committed
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
            exact_match = 1.0
        else:
            exact_match = 0.0

        f1_per_bag = self._align_bags(predicted_bags[1], gold_bags[1])
        f1 = np.mean(f1_per_bag)
        f1 = round(f1, 2)
        return exact_match, f1

    def _answer_to_bags(self, answer):
        if isinstance(answer, (list, tuple)):
            raw_spans = answer
        else:
            raw_spans = [answer]
        normalized_spans = []
        token_bags = []
        for raw_span in raw_spans:
            normalized_span = self._normalize(raw_span)
            normalized_spans.append(normalized_span)
            token_bags.append(set(normalized_span.split()))
        return normalized_spans, token_bags

    def _align_bags(self, predicted, gold):
        """
        Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
        between them and gets maximum metric values over all the answers.
        """
        scores = np.zeros([len(gold), len(predicted)])
        for gold_index, gold_item in enumerate(gold):
            for pred_index, pred_item in enumerate(predicted):
                if self._match_numbers_if_present(gold_item, pred_item):
bzantium's avatar
bzantium committed
205
206
207
                    scores[gold_index, pred_index] = self._compute_f1(
                        pred_item, gold_item
                    )
Jon Tow's avatar
Jon Tow committed
208
        row_ind, col_ind = linear_sum_assignment(-scores)
silentv0x's avatar
silentv0x committed
209
210

        max_scores = np.zeros([max(len(gold), len(predicted))])
Jon Tow's avatar
Jon Tow committed
211
212
213
214
        for row, column in zip(row_ind, col_ind):
            max_scores[row] = max(max_scores[row], scores[row, column])
        return max_scores

silentv0x's avatar
silentv0x committed
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
    def _compute_f1(self, predicted_bag, gold_bag):
        intersection = len(gold_bag.intersection(predicted_bag))
        if not predicted_bag:
            precision = 1.0
        else:
            precision = intersection / float(len(predicted_bag))
        if not gold_bag:
            recall = 1.0
        else:
            recall = intersection / float(len(gold_bag))
        f1 = (
            (2 * precision * recall) / (precision + recall)
            if not (precision == 0.0 and recall == 0.0)
            else 0.0
        )
Jon Tow's avatar
Jon Tow committed
230
231
        return f1

silentv0x's avatar
silentv0x committed
232
233
234
235
236
237
238
239
240
241
    def _match_numbers_if_present(self, gold_bag, predicted_bag):
        gold_numbers = set()
        predicted_numbers = set()
        for word in gold_bag:
            if self._is_number(word):
                gold_numbers.add(word)
        for word in predicted_bag:
            if self._is_number(word):
                predicted_numbers.add(word)
        if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
242
243
244
245
246
247
248
249
250
            return True
        return False

    def _is_number(self, text):
        try:
            float(text)
            return True
        except ValueError:
            return False
Jon Tow's avatar
Jon Tow committed
251

silentv0x's avatar
silentv0x committed
252
253
    def _remove_articles(self, text):
        """Replace every match of the module-level ``_ARTICLES`` pattern with a space."""
        without_articles = _ARTICLES.sub(" ", text)
        return without_articles
254

silentv0x's avatar
silentv0x committed
255
256
    def _white_space_fix(self, text):
        return " ".join(text.split())
257

silentv0x's avatar
silentv0x committed
258
259
260
261
262
263
    def _remove_punc(self, text):
        exclude = set(string.punctuation)
        if not self._is_number(text):
            return "".join(ch for ch in text if ch not in exclude)
        else:
            return text
264

silentv0x's avatar
silentv0x committed
265
266
    def _fix_number(self, text):
        return str(float(text)) if self._is_number(text) else text
267

silentv0x's avatar
Bug fix  
silentv0x committed
268
    def _tokenize(self, text):
silentv0x's avatar
silentv0x committed
269
        return re.split(" |-", text)
270

silentv0x's avatar
silentv0x committed
271
    def _normalize(self, answer):
272
        tokens = [
bzantium's avatar
bzantium committed
273
274
275
276
277
            self._white_space_fix(
                self._remove_articles(
                    self._fix_number(self._remove_punc(token.lower()))
                )
            )
silentv0x's avatar
silentv0x committed
278
            for token in self._tokenize(answer)
279
        ]
Jon Tow's avatar
Fixes  
Jon Tow committed
280
        tokens = [token for token in tokens if token.strip()]
Jon Tow's avatar
Jon Tow committed
281
282
        normalized = " ".join(tokens).strip()
        return normalized
Leo Gao's avatar
Leo Gao committed
283
284
285
286

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary mapping each submetric name (``em`` and ``f1``) to the
            function that aggregates its per-document values; both use ``mean``.
        """
        return dict.fromkeys(("em", "f1"), mean)
Leo Gao's avatar
Leo Gao committed
291
292
293
294

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary mapping each submetric name (``em`` and ``f1``) to
            whether a higher value is better; both are ``True``.
        """
        return dict.fromkeys(("em", "f1"), True)