drop.py 8.25 KB
Newer Older
Anish Thite's avatar
Anish Thite committed
1
import json
Jon Tow's avatar
Jon Tow committed
2
3
4
5
6
7
8
import numpy as np
import re
import transformers.data.metrics.squad_metrics as squad_metrics
from best_download import download_file
from scipy.optimize import linear_sum_assignment
from lm_eval.base import Task, rf
from lm_eval.metrics import mean
Anish Thite's avatar
Anish Thite committed
9
from pathlib import Path
Jon Tow's avatar
Jon Tow committed
10
11
from zipfile import ZipFile

Anish Thite's avatar
Anish Thite committed
12

13
class DROP(Task):
    """DROP question-answering task.

    Each raw passage is flattened into one record per question (see
    `_load_docs`), and predictions are scored with exact match and a
    bag-of-words F1.
    """

    # Local directory the dataset archive is downloaded to and extracted in.
    DATAFOLDER = Path("data/drop")
    # Remote location of the DROP dataset zip archive.
    URL = "https://s3-us-west-2.amazonaws.com/allennlp/datasets/drop/drop_dataset.zip"

    def download(self):
        """Download and extract the DROP archive into `DATAFOLDER`.

        Does nothing when the extracted `drop_dataset` directory (the one
        `training_docs` / `validation_docs` read from) already exists. The
        check targets the extracted directory rather than `DATAFOLDER` itself
        so that an interrupted download or extraction is retried on the next
        call instead of being silently treated as complete.
        """
        extracted_dir = self.DATAFOLDER / "drop_dataset"
        if extracted_dir.exists():
            return
        # parents=True: the parent `data/` directory may not exist yet.
        # exist_ok=True: an earlier, interrupted run may have created it.
        self.DATAFOLDER.mkdir(parents=True, exist_ok=True)
        archive_path = self.DATAFOLDER / "drop_dataset.zip"
        download_file(self.URL, to=str(archive_path))
        with ZipFile(archive_path, "r") as archive:
            archive.extractall(self.DATAFOLDER)
24

Anish Thite's avatar
Anish Thite committed
25
26
    def has_training_docs(self):
        """Report that a training split is available (served by `training_docs`)."""
        available = True
        return available
Jon Tow's avatar
Jon Tow committed
27

Anish Thite's avatar
Anish Thite committed
28
29
30
31
32
33
    def has_validation_docs(self):
        """Report that a validation split is available (served by `validation_docs`)."""
        available = True
        return available

    def has_test_docs(self):
        """Report that no test split is provided for this task."""
        available = False
        return available

Jon Tow's avatar
Jon Tow committed
34
35
36
37
38
39
40
41
42
43
44
45
    def fewshot_description(self):
        """Return the task description prepended to few-shot prompts.

        Currently the empty string: no description has been written yet.
        """
        # TODO: figure out description
        description = ""
        return description

    def _load_docs(self, docs):
        for doc in docs:
            for qa in doc["qa_pairs"]:
                yield {
                    "passage": doc["passage"],
                    "question": qa["question"],
                    "answers": self.get_answers(qa["answer"]),
                }
Anish Thite's avatar
Anish Thite committed
46

Jon Tow's avatar
Jon Tow committed
47
48
49
50
51
52
53
54
55
56
57
58
59
60
    @classmethod
    def get_answers(cls, answers):
        # NOTE: We wrap every non-`list` answer into a list for uniformity.
        if answers["number"] != "":
            return [answers["number"]]
        if answers["spans"] != []:
            return answers["spans"]
        return [" ".join([answers["date"]["day"],
                          answers["date"]["month"],
                          answers["date"]["year"]]).strip()]

    def training_docs(self):
        docs = json.load(open(self.DATAFOLDER / "drop_dataset" / "drop_dataset_train.json"))
        return self._load_docs([docs[k] for k in docs.keys()])
Anish Thite's avatar
Anish Thite committed
61
62

    def validation_docs(self):
Jon Tow's avatar
Jon Tow committed
63
64
65
        docs = json.load(open(self.DATAFOLDER / "drop_dataset" / "drop_dataset_dev.json"))
        return self._load_docs([docs[k] for k in docs.keys()])

Anish Thite's avatar
Anish Thite committed
66
67
68
    def test_docs(self):
        pass

Jon Tow's avatar
Jon Tow committed
69
70
71
72
73
    def doc_to_text(self, doc):
        """Build the prompt: the passage, the question, then an "Answer:" cue line."""
        prompt_lines = [
            f"Passage: {doc['passage']}",
            f"Question: {doc['question']}",
            "Answer:",
        ]
        return "\n".join(prompt_lines)

    def doc_to_target(self, doc):
        """Build the gold continuation: a leading space, then the answers joined by ", "."""
        joined_answers = ", ".join(doc["answers"])
        return f" {joined_answers}"
Anish Thite's avatar
Anish Thite committed
74

Leo Gao's avatar
Leo Gao committed
75
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :returns: a list with one greedy-generation request per entry in
            doc["answers"], each stopping at a newline or a period.
        """
        # NOTE(review): every request built here is identical (same ctx, same
        # stop sequences). `process_results` scores `results` as the list of
        # predictions, so changing the number of requests changes its metrics.
        return [rf.greedy_until(ctx, ["\n", "."]) for _ in doc["answers"]]

Leo Gao's avatar
Leo Gao committed
91
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        :returns: {"em": ..., "f1": ...} where "em" comes from `_exact_match`
            and "f1" from `_f1_score`, both comparing doc["answers"] (gold)
            against `results` (predictions).
        """
        golds, preds = doc["answers"], results
        exact_match = self._exact_match(golds, preds)
        f1_score = self._f1_score(golds, preds)
        return {"em": exact_match, "f1": f1_score}

    def _exact_match(self, golds, preds):
        """ Returns the exact match of normalized gold answers and predictions. """
        normalized_golds = set([self._normalize(gold) for gold in golds])
        normalized_preds = set([self._normalize(pred) for pred in preds])
        return int(normalized_golds == normalized_preds)

    def _f1_score(self, golds, preds):
        """Returns the average F1-score over normalized `gold` and `pred`
        answer lists.
        """
        gold_bags = self._answer_to_bags(golds)
        print("GOLD BAGS: " + str(gold_bags))
        pred_bags = self._answer_to_bags(preds)
        print("PRED BAGS: " + str(pred_bags))
        f1_per_bag = self._align_bags(gold_bags, pred_bags)
        return np.mean(f1_per_bag)

    def _answer_to_bags(self, answers):
        return [set(self._normalize(answer).split()) for answer in answers]

    def _align_bags(self, gold_bags, pred_bags):
        """ Returns the max metric value over all the answers. """
        scores = np.zeros([len(gold_bags), len(pred_bags)])
        for gold_index, gold_bag in enumerate(gold_bags):
            for pred_index, pred_bag in enumerate(pred_bags):
                print(self._is_number_match(gold_bag, pred_bag))
                if self._is_number_match(gold_bag, pred_bag):
                    scores[gold_index, pred_index] = self._bag_f1(pred_bag, gold_bag)
        print(scores)
        row_ind, col_ind = linear_sum_assignment(-scores)
        max_scores = np.zeros([max(len(gold_bags), len(pred_bags))])
        for row, column in zip(row_ind, col_ind):
            max_scores[row] = max(max_scores[row], scores[row, column])
        return max_scores

    def _bag_f1(self, gold_bag, pred_bag):
        intersection = len(gold_bag.intersection(pred_bag))
        if intersection == 0:
            return 0.0
        precision = intersection / float(len(pred_bag)) if pred_bag else 1.0
        recall = intersection / float(len(gold_bag)) if gold_bag else 1.0
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    def _is_number_match(self, gold_bag, pred_bag):
        gold_numbers = set(filter(lambda s: s.isnumeric(), list(gold_bag)))
        pred_numbers = set(filter(lambda s: s.isnumeric(), list(pred_bag)))
        return (not gold_numbers) or gold_numbers.intersection(pred_numbers)

    def _normalize(self, answer):
        """Normalize an answer string for comparison.

        The answer is split on spaces and hyphens, each piece is passed
        through `squad_metrics.normalize_answer`, and the non-empty results
        are joined with single spaces.
        """
        pieces = re.split(r" |-", answer)
        parts = [squad_metrics.normalize_answer(piece) for piece in pieces]
        # Pieces can normalize to "" (e.g. from consecutive separators, or
        # tokens normalize_answer strips entirely). Dropping them avoids
        # double spaces that would make otherwise-equal answers differ.
        return " ".join(part for part in parts if part.strip()).strip()
Leo Gao's avatar
Leo Gao committed
163
164
165
166

    def aggregation(self):
        """Map each submetric to the function that combines per-document values.

        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics; both exact match
            ("em") and F1 ("f1") are combined with `mean`.
        """
        aggregators = {}
        for metric_name in ("em", "f1"):
            aggregators[metric_name] = mean
        return aggregators
Leo Gao's avatar
Leo Gao committed
172
173
174
175

    def higher_is_better(self):
        """Report the optimization direction of each submetric.

        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better; both exact
            match ("em") and F1 ("f1") improve as they increase.
        """
        return dict.fromkeys(("em", "f1"), True)


# Temporary sanity checks for the scoring helpers (run this module directly to execute them).


def main():
    """Run ad-hoc sanity checks of the DROP answer-scoring helpers.

    Only the bag-construction check runs by default; the other checks are
    kept for manual debugging and are currently not invoked.
    """
    drop = DROP()

    def check_bags():
        pacific_answers = ["Pacific Ocean", "Pacific"]
        pacific_bags = drop._answer_to_bags(pacific_answers)
        print(f"Multiple Choice Answer Bags: {pacific_answers} => {pacific_bags}")
        assert len(pacific_bags) == 2
        year_answer = ["1974"]
        year_bags = drop._answer_to_bags(year_answer)
        print(f"Number Bags: {year_answer} => {year_bags}")
        print()

    def check_is_number_match():
        gold_bags = drop._answer_to_bags(["10 29 1999"])
        pred_bags = drop._answer_to_bags(["4 29 1990"])
        print(gold_bags)
        print(pred_bags)
        for gold_bag in gold_bags:
            for pred_bag in pred_bags:
                print(drop._is_number_match(gold_bag, pred_bag))
        print()

    def check_exact_match():
        print(drop._exact_match(["Bob Ross"], ["Bob Ross"]))

    def check_f1_score():
        gold = ["25 to 44"]
        pred = ["25 to 44 or 45 to 64"]
        f1 = drop._f1_score(gold, pred)
        print(gold)
        print(pred)
        print(f1)
        print(drop._f1_score(["300", "1992"], ["300", "1992"]))

    check_bags()
    # Disabled by default; call manually when debugging scoring:
    # check_is_number_match()
    # check_exact_match()
    # check_f1_score()


if __name__ == "__main__":
    main()