drop.py 8.25 KB
Newer Older
Anish Thite's avatar
Anish Thite committed
1
import json
Jon Tow's avatar
Jon Tow committed
2
3
import numpy as np
import re
4
import string
Jon Tow's avatar
Jon Tow committed
5
6
7
8
from best_download import download_file
from scipy.optimize import linear_sum_assignment
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
Anish Thite's avatar
Anish Thite committed
9
from pathlib import Path
Jon Tow's avatar
Jon Tow committed
10
11
from zipfile import ZipFile

12
13
14
15
16
"""
Acknowledgement: This implementation is based on the official evaluation for `DROP`:
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
"""

Anish Thite's avatar
Anish Thite committed
17

18
class DROP(Task):
Leo Gao's avatar
Leo Gao committed
19
    VERSION = 0
20
    DATASET_PATH = Path("data/drop")
Jon Tow's avatar
Jon Tow committed
21
22

    def download(self):
        """Fetch the DROP archive and unpack it under ``DATASET_PATH``.

        Returns immediately when ``DATASET_PATH`` already exists. Otherwise
        the directory is created, the zip archive is fetched through
        ``download_file`` (which is handed the checksum string below) and
        its contents are extracted into ``DATASET_PATH``.
        """
        if self.DATASET_PATH.exists():
            return
        # `parents=True` so that a missing parent directory (e.g. `data/`)
        # is created instead of raising FileNotFoundError.
        self.DATASET_PATH.mkdir(parents=True, exist_ok=True)
        url = "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip"
        checksum = "39d2278a29fd729de301b111a45f434c24834f40df8f4ff116d864589e3249d6"
        zip_path = self.DATASET_PATH / "drop_dataset.zip"
        download_file(url, str(zip_path), checksum)
        # Named `archive` rather than `zip` so the builtin is not shadowed.
        with ZipFile(zip_path, "r") as archive:
            archive.extractall(self.DATASET_PATH)
32

Anish Thite's avatar
Anish Thite committed
33
34
    def has_training_docs(self):
        """Report that this task provides training documents."""
        available = True
        return available
Jon Tow's avatar
Jon Tow committed
35

Anish Thite's avatar
Anish Thite committed
36
37
38
39
40
41
    def has_validation_docs(self):
        """Report that this task provides validation documents."""
        available = True
        return available

    def has_test_docs(self):
        """Report that this task provides no test documents."""
        available = False
        return available

Jon Tow's avatar
Jon Tow committed
42
43
44
45
46
47
48
49
    def fewshot_description(self):
        """Return the few-shot description string, which is currently empty."""
        # TODO: figure out description
        description = ""
        return description

    def _load_docs(self, docs):
        for doc in docs:
            for qa in doc["qa_pairs"]:
                yield {
Jon Tow's avatar
Jon Tow committed
50
                    "id": qa["query_id"],
Jon Tow's avatar
Jon Tow committed
51
52
53
54
                    "passage": doc["passage"],
                    "question": qa["question"],
                    "answers": self.get_answers(qa["answer"]),
                }
Anish Thite's avatar
Anish Thite committed
55

Jon Tow's avatar
Jon Tow committed
56
57
58
59
    @classmethod
    def get_answers(cls, answers):
        # NOTE: We wrap every non-`list` answer into a list for uniformity.
        if answers["number"] != "":
Jon Tow's avatar
Jon Tow committed
60
            return [str(answers["number"])]
Jon Tow's avatar
Jon Tow committed
61
62
63
64
65
66
67
        if answers["spans"] != []:
            return answers["spans"]
        return [" ".join([answers["date"]["day"],
                          answers["date"]["month"],
                          answers["date"]["year"]]).strip()]

    def training_docs(self):
68
        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_train.json"))
Jon Tow's avatar
Jon Tow committed
69
        return self._load_docs([docs[k] for k in docs.keys()])
Anish Thite's avatar
Anish Thite committed
70
71

    def validation_docs(self):
72
        docs = json.load(open(self.DATASET_PATH / "drop_dataset" / "drop_dataset_dev.json"))
Jon Tow's avatar
Jon Tow committed
73
74
75
76
77
78
79
        return self._load_docs([docs[k] for k in docs.keys()])

    def doc_to_text(self, doc):
        """Render the prompt: the passage, the question, and a trailing "Answer:" cue."""
        passage, question = doc["passage"], doc["question"]
        return f"Passage: {passage}\nQuestion: {question}\nAnswer:"

    def doc_to_target(self, doc):
        """Return the gold answers joined by ", " and prefixed with one space."""
        joined = ", ".join(doc["answers"])
        return " " + joined
Anish Thite's avatar
Anish Thite committed
80

Leo Gao's avatar
Leo Gao committed
81
    def construct_requests(self, doc, ctx):
        """Build the Requests that will be sent to the LM for one document.

        One `rf.greedy_until` request with the stop sequence "." is created
        per entry in ``doc["answers"]``; all of them share the same context.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :returns:
            A list of requests, one per gold answer.
        """
        return [rf.greedy_until(ctx, ["."]) for _ in doc["answers"]]

Leo Gao's avatar
Leo Gao committed
97
    def process_results(self, doc, results):
Jon Tow's avatar
Jon Tow committed
98
99
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
Leo Gao's avatar
Leo Gao committed
100
101
        the metric for that one document

Jon Tow's avatar
Jon Tow committed
102
        :param doc:
Jon Tow's avatar
Jon Tow committed
103
            The document as returned from training_docs, validation_docs, or test_docs.
Leo Gao's avatar
Leo Gao committed
104
105
106
        :param results:
            The results of the requests created in construct_requests.
        """
107
108
        preds, golds = results, doc["answers"]
        exact_match, f1_score = self.get_metrics(preds, golds)
Jon Tow's avatar
Jon Tow committed
109
110
111
112
        return {
            "em": exact_match,
            "f1": f1_score
        }
Jon Tow's avatar
Jon Tow committed
113

114
115
116
117
    def get_metrics(self, preds, golds):
        exact_match = self._exact_match(preds, golds)
        f1_score = self._f1_score(preds, golds)
        return exact_match, f1_score
Jon Tow's avatar
Jon Tow committed
118

119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
    def _exact_match(self, preds, golds):
        """ Returns the exact match of normalized gold answers and predictions. """
        normalized_preds = [self._normalize(pred) for pred in preds]
        normalized_golds = [self._normalize(gold) for gold in golds]
        is_equal_sets = set(normalized_preds) == set(normalized_golds)
        is_equal_length = len(normalized_preds) == len(normalized_golds)
        return int(is_equal_sets and is_equal_length)

    def _f1_score(self, preds, golds):
        """Returns the average F1-score over normalized gold answers and predictions.
        From Section 5 of Dua et al. "DROP:...":
        "When an answer has multiple spans, we first perform a one-to-one
        alignment greedily based on bag-of-word overlap on the set of spans
        and then compute average F1 over each span."
        """
Jon Tow's avatar
Jon Tow committed
134
        pred_bags = self._answer_to_bags(preds)
135
136
        gold_bags = self._answer_to_bags(golds)
        f1_per_bag = self._align_bags(pred_bags, gold_bags)
Jon Tow's avatar
Jon Tow committed
137
138
139
140
141
        return np.mean(f1_per_bag)

    def _answer_to_bags(self, answers):
        return [set(self._normalize(answer).split()) for answer in answers]

142
    def _align_bags(self, pred_bags, gold_bags):
Jon Tow's avatar
Jon Tow committed
143
144
145
146
        """ Returns the max metric value over all the answers. """
        scores = np.zeros([len(gold_bags), len(pred_bags)])
        for gold_index, gold_bag in enumerate(gold_bags):
            for pred_index, pred_bag in enumerate(pred_bags):
147
148
                if self._is_number_match(pred_bag, gold_bag):
                    scores[gold_index, pred_index] = self._bag_f1(pred_bag, gold_bag)
Jon Tow's avatar
Jon Tow committed
149
150
151
152
153
154
        row_ind, col_ind = linear_sum_assignment(-scores)
        max_scores = np.zeros([max(len(gold_bags), len(pred_bags))])
        for row, column in zip(row_ind, col_ind):
            max_scores[row] = max(max_scores[row], scores[row, column])
        return max_scores

155
    def _bag_f1(self, pred_bag, gold_bag):
Jon Tow's avatar
Jon Tow committed
156
157
158
159
160
161
162
163
        intersection = len(gold_bag.intersection(pred_bag))
        if intersection == 0:
            return 0.0
        precision = intersection / float(len(pred_bag)) if pred_bag else 1.0
        recall = intersection / float(len(gold_bag)) if gold_bag else 1.0
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

164
165
166
167
168
169
170
171
172
173
174
175
176
    def _is_number_match(self, pred_bag, gold_bag):
        pred_numbers = set([word for word in pred_bag if self._is_number(word)])
        gold_numbers = set([word for word in gold_bag if self._is_number(word)])
        if (not gold_numbers) or gold_numbers.intersection(pred_numbers):
            return True
        return False

    def _is_number(self, text):
        try:
            float(text)
            return True
        except ValueError:
            return False
Jon Tow's avatar
Jon Tow committed
177
178

    def _normalize(self, answer):
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
        def remove_articles(text):
            regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
            return re.sub(regex, " ", text)

        def white_space_fix(text):
            return " ".join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            if not self._is_number(text):
                return "".join(ch for ch in text if ch not in exclude)
            else:
                return text

        def fix_number(text):
            return str(float(text)) if self._is_number(text) else text

Jon Tow's avatar
Jon Tow committed
196
197
        def tokenize(text):
            return re.split(" |-", text)
198
199
200
201
202

        tokens = [
            white_space_fix(remove_articles(fix_number(remove_punc(token.lower()))))
            for token in tokenize(answer)
        ]
Jon Tow's avatar
Fixes  
Jon Tow committed
203
        tokens = [token for token in tokens if token.strip()]
Jon Tow's avatar
Jon Tow committed
204
205
        normalized = " ".join(tokens).strip()
        return normalized
Leo Gao's avatar
Leo Gao committed
206
207
208
209

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics; both "em" and "f1"
            use `mean`.
        """
        return {name: mean for name in ("em", "f1")}
Leo Gao's avatar
Leo Gao committed
217
218
219
220

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better; True for both
            "em" and "f1".
        """
        return dict.fromkeys(("em", "f1"), True)