drop.py 9.28 KB
Newer Older
Anish Thite's avatar
Anish Thite committed
1
import json
Jon Tow's avatar
Jon Tow committed
2
3
import numpy as np
import re
4
import string
Jon Tow's avatar
Jon Tow committed
5
6
7
8
from best_download import download_file
from scipy.optimize import linear_sum_assignment
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
Anish Thite's avatar
Anish Thite committed
9
from pathlib import Path
Jon Tow's avatar
Jon Tow committed
10
11
from zipfile import ZipFile

12
13
14
15
16
"""
Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
"""

silentv0x's avatar
silentv0x committed
17
# Whole-word matcher for the English articles "a", "an" and "the"; used by
# DROP._remove_articles to blank them out during answer normalization.
_ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
Anish Thite's avatar
Anish Thite committed
18

19
class DROP(Task):
Leo Gao's avatar
Leo Gao committed
20
    VERSION = 1
21
    DATASET_PATH = Path("data/drop")
Jon Tow's avatar
Jon Tow committed
22
23

    def download(self):
        """Download and unpack the DROP archive under ``DATASET_PATH``.

        Nothing is done when both JSON splits read by ``training_docs`` and
        ``validation_docs`` are already present. Checking for the extracted
        files (rather than just the directory) means an interrupted download
        or extraction is retried on the next call instead of leaving the task
        permanently without data.
        """
        extracted_dir = self.DATASET_PATH / "drop_dataset"
        required_files = (
            extracted_dir / "drop_dataset_train.json",
            extracted_dir / "drop_dataset_dev.json",
        )
        if all(path.exists() for path in required_files):
            return

        # exist_ok: the directory may be left over from an earlier, failed run.
        self.DATASET_PATH.mkdir(parents=True, exist_ok=True)
        url = "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip"
        checksum = "39d2278a29fd729de301b111a45f434c24834f40df8f4ff116d864589e3249d6"
        zip_path = self.DATASET_PATH / "drop_dataset.zip"
        download_file(url, local_file=str(zip_path), expected_checksum=checksum)
        # Named ``archive`` so the builtin ``zip`` is not shadowed.
        with ZipFile(zip_path, "r") as archive:
            archive.extractall(self.DATASET_PATH)
33

Anish Thite's avatar
Anish Thite committed
34
35
    def has_training_docs(self):
        """Return True: this task exposes a training split via ``training_docs``."""
        return True
Jon Tow's avatar
Jon Tow committed
36

Anish Thite's avatar
Anish Thite committed
37
38
39
40
41
42
    def has_validation_docs(self):
        """Return True: this task exposes a validation split via ``validation_docs``."""
        return True

    def has_test_docs(self):
        """Return False: this task does not provide a test split."""
        return False

Jon Tow's avatar
Jon Tow committed
43
44
45
46
    def _load_docs(self, docs):
        for doc in docs:
            for qa in doc["qa_pairs"]:
                yield {
Jon Tow's avatar
Jon Tow committed
47
                    "id": qa["query_id"],
Jon Tow's avatar
Jon Tow committed
48
49
                    "passage": doc["passage"],
                    "question": qa["question"],
silentv0x's avatar
silentv0x committed
50
                    "answers": self.get_answers(qa),
Jon Tow's avatar
Jon Tow committed
51
                }
Anish Thite's avatar
Anish Thite committed
52

Jon Tow's avatar
Jon Tow committed
53
    @classmethod
silentv0x's avatar
silentv0x committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
    def get_answers(cls, qa):
        answers = []
        answers_set = set()

        candidates = [qa["answer"]] + qa.get("validated_answers", [])
        for candidate in candidates:
            answer = cls.parse_answer(candidate)
            if answer in answers_set:
                continue
            answers_set.add(answer)
            answers.append(answer)

        return answers

    @classmethod
    def parse_answer(cls, answer):
        # NOTE: Everything is returned as a tuple for uniformity and hashability.
        if answer["number"] != "":
            return (str(answer["number"]),)
        if answer["spans"] != []:
            return tuple(answer["spans"])
        return (" ".join([answer["date"]["day"],
                          answer["date"]["month"],
                          answer["date"]["year"]]).strip(),)
Jon Tow's avatar
Jon Tow committed
78
79

    def training_docs(self):
80
        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json"))
Jon Tow's avatar
Jon Tow committed
81
        return self._load_docs([docs[k] for k in docs.keys()])
Anish Thite's avatar
Anish Thite committed
82
83

    def validation_docs(self):
84
        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_dev.json"))
Jon Tow's avatar
Jon Tow committed
85
86
87
88
89
90
        return self._load_docs([docs[k] for k in docs.keys()])

    def doc_to_text(self, doc):
        """Render the prompt for ``doc``: passage, question, then an open ``Answer:`` cue."""
        passage = doc["passage"]
        question = doc["question"]
        return f"Passage: {passage}\nQuestion: {question}\nAnswer:"

    def doc_to_target(self, doc):
        """Render the first gold answer as the target continuation.

        The spans of ``doc["answers"][0]`` are joined with ``", "`` and
        prefixed with a single space.
        """
        first_gold = doc["answers"][0]
        joined_spans = ", ".join(first_gold)
        return " " + joined_spans
Anish Thite's avatar
Anish Thite committed
92

Leo Gao's avatar
Leo Gao committed
93
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :returns:
            A one-element list holding a single ``greedy_until`` request.
        """
        # One generation request for `ctx`; ["."] is passed as the second
        # argument (presumably the stop sequences -- confirm against rf).
        return [rf.greedy_until(ctx, ["."])]

Leo Gao's avatar
Leo Gao committed
107
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        The prediction is scored against every gold answer in
        ``doc["answers"]`` and the best exact-match and F1 values are kept;
        gold answers whose first span is blank do not contribute.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        :returns:
            ``{"em": best exact match, "f1": best F1}``.
        """
        best_em = 0
        best_f1 = 0
        for gold in doc["answers"]:
            em, f1 = self.get_metrics(results, gold)
            if not gold[0].strip():
                # Blank gold answer: scored above but not counted.
                continue
            best_em = max(best_em, em)
            best_f1 = max(best_f1, f1)
        return {"em": best_em, "f1": best_f1}
Jon Tow's avatar
Jon Tow committed
129

silentv0x's avatar
silentv0x committed
130
131
132
133
134
135
136
    def get_metrics(self, predicted, gold):
        """Score one prediction against one gold answer.

        Both arguments may be a single string or a list/tuple of strings.
        Each is normalized with ``_answer_to_bags``.

        :returns:
            ``(exact_match, f1)``. ``exact_match`` is 1.0 when the normalized
            span sets are equal and the span counts match, else 0.0. ``f1`` is
            the mean of the aligned per-bag F1 scores from ``_align_bags``,
            rounded to two decimals.
        """
        pred_spans, pred_bags = self._answer_to_bags(predicted)
        gold_spans, gold_bags = self._answer_to_bags(gold)

        same_spans = set(pred_spans) == set(gold_spans) and len(pred_spans) == len(gold_spans)
        exact_match = 1.0 if same_spans else 0.0

        per_bag_f1 = self._align_bags(pred_bags, gold_bags)
        f1 = round(np.mean(per_bag_f1), 2)
        return exact_match, f1

    def _answer_to_bags(self, answer):
        """Normalize an answer and split it into token bags.

        :param answer:
            A single span, or a list/tuple of spans.
        :returns:
            ``(normalized_spans, token_bags)``: the list of normalized span
            strings and, in the same order, the set of whitespace-separated
            tokens of each one.
        """
        spans = answer if isinstance(answer, (list, tuple)) else [answer]
        normalized_spans = [self._normalize(span) for span in spans]
        token_bags = [set(span.split()) for span in normalized_spans]
        return normalized_spans, token_bags

    def _align_bags(self, predicted, gold):
        """
        Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
        between them and gets maximum metric values over all the answers.
        """
        scores = np.zeros([len(gold), len(predicted)])
        for gold_index, gold_item in enumerate(gold):
            for pred_index, pred_item in enumerate(predicted):
                if self._match_numbers_if_present(gold_item, pred_item):
                    scores[gold_index, pred_index] = self._compute_f1(pred_item, gold_item)
Jon Tow's avatar
Jon Tow committed
174
        row_ind, col_ind = linear_sum_assignment(-scores)
silentv0x's avatar
silentv0x committed
175
176

        max_scores = np.zeros([max(len(gold), len(predicted))])
Jon Tow's avatar
Jon Tow committed
177
178
179
180
        for row, column in zip(row_ind, col_ind):
            max_scores[row] = max(max_scores[row], scores[row, column])
        return max_scores

silentv0x's avatar
silentv0x committed
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
    def _compute_f1(self, predicted_bag, gold_bag):
        intersection = len(gold_bag.intersection(predicted_bag))
        if not predicted_bag:
            precision = 1.0
        else:
            precision = intersection / float(len(predicted_bag))
        if not gold_bag:
            recall = 1.0
        else:
            recall = intersection / float(len(gold_bag))
        f1 = (
            (2 * precision * recall) / (precision + recall)
            if not (precision == 0.0 and recall == 0.0)
            else 0.0
        )
Jon Tow's avatar
Jon Tow committed
196
197
        return f1

silentv0x's avatar
silentv0x committed
198
199
200
201
202
203
204
205
206
207
    def _match_numbers_if_present(self, gold_bag, predicted_bag):
        gold_numbers = set()
        predicted_numbers = set()
        for word in gold_bag:
            if self._is_number(word):
                gold_numbers.add(word)
        for word in predicted_bag:
            if self._is_number(word):
                predicted_numbers.add(word)
        if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
208
209
210
211
212
213
214
215
216
            return True
        return False

    def _is_number(self, text):
        try:
            float(text)
            return True
        except ValueError:
            return False
Jon Tow's avatar
Jon Tow committed
217

silentv0x's avatar
silentv0x committed
218
219
    def _remove_articles(self, text):
        """Replace every match of the module-level ``_ARTICLES`` pattern with a space."""
        without_articles = _ARTICLES.sub(" ", text)
        return without_articles
220

silentv0x's avatar
silentv0x committed
221
222
    def _white_space_fix(self, text):
        return " ".join(text.split())
223

silentv0x's avatar
silentv0x committed
224
225
226
227
228
229
    def _remove_punc(self, text):
        exclude = set(string.punctuation)
        if not self._is_number(text):
            return "".join(ch for ch in text if ch not in exclude)
        else:
            return text
230

silentv0x's avatar
silentv0x committed
231
232
    def _fix_number(self, text):
        return str(float(text)) if self._is_number(text) else text
233

silentv0x's avatar
Bug fix  
silentv0x committed
234
    def _tokenize(self, text):
silentv0x's avatar
silentv0x committed
235
        return re.split(" |-", text)
236

silentv0x's avatar
silentv0x committed
237
    def _normalize(self, answer):
        """Normalize an answer string for comparison.

        The text is tokenized with ``_tokenize``; each token is lower-cased,
        stripped of punctuation, converted to canonical numeric form, stripped
        of articles and whitespace-collapsed, in that order. Tokens that end
        up blank are dropped and the rest are joined with single spaces.
        """
        kept = []
        for raw_token in self._tokenize(answer):
            token = self._remove_punc(raw_token.lower())
            token = self._fix_number(token)
            token = self._remove_articles(token)
            token = self._white_space_fix(token)
            if token.strip():
                kept.append(token)
        return " ".join(kept).strip()
Leo Gao's avatar
Leo Gao committed
245
246
247
248

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        # Both submetrics are aggregated with the same `mean` function.
        return dict.fromkeys(("em", "f1"), mean)
Leo Gao's avatar
Leo Gao committed
256
257
258
259

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        # Higher is better for both exact match and F1.
        return dict.fromkeys(("em", "f1"), True)