task.py 14.1 KB
Newer Older
lintangsutawika's avatar
lintangsutawika committed
1
import re
2
3
4
from abc import abstractmethod
from functools import reduce

Baber's avatar
Baber committed
5
import datasets
lintangsutawika's avatar
lintangsutawika committed
6
7
import numpy as np
import transformers.data.metrics.squad_metrics as squad_metrics
Baber Abbasi's avatar
Baber Abbasi committed
8
from evaluate import load
lintangsutawika's avatar
lintangsutawika committed
9
10
11
from transformers import AutoTokenizer

from lm_eval.api.instance import Instance
12
from lm_eval.api.metrics import mean
13
from lm_eval.api.task import ConfigurableTask
14

lintangsutawika's avatar
lintangsutawika committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47

# BibTeX entry for the SCROLLS benchmark paper (Shaham et al., EMNLP 2022).
_CITATION = """
@inproceedings{shaham-etal-2022-scrolls,
    title = "{SCROLLS}: Standardized {C}ompa{R}ison Over Long Language Sequences",
    author = "Shaham, Uri  and
      Segal, Elad  and
      Ivgi, Maor  and
      Efrat, Avia  and
      Yoran, Ori  and
      Haviv, Adi  and
      Gupta, Ankit  and
      Xiong, Wenhan  and
      Geva, Mor  and
      Berant, Jonathan  and
      Levy, Omer",
    booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing",
    month = dec,
    year = "2022",
    address = "Abu Dhabi, United Arab Emirates",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2022.emnlp-main.823",
    pages = "12007--12021"
}
"""

# SCROLLS is formulated as a sequence-to-sequence task.
# To allow for evaluation of causal models, we
# reformulate these tasks with appropriate prompts.


def _download_metric():
    """Fetch the official SCROLLS metric script and return a loadable local copy.

    Downloads ``metrics/scrolls.py`` from the ``tau/scrolls`` dataset repository
    on the Hugging Face Hub (pinned to revision ``refs/pr/5``). It then copies the
    file into the same directory under a new name, with every "." in the file
    name replaced by "_" and ".py" appended, i.e. ``scrolls_py.py``.

    Returns:
        str: path of the copied script, suitable for ``evaluate.load``.
    """
    import os
    import shutil

    from huggingface_hub import hf_hub_download

    scrolls_metric_path = hf_hub_download(
        repo_id="tau/scrolls",
        repo_type="dataset",
        filename="metrics/scrolls.py",
        revision="refs/pr/5",
    )
    # Join with a real path separator so the copy lands beside the downloaded
    # script (plain string concatenation would glue the directory name and the
    # new file name together, one directory up).
    updated_scrolls_metric_path = os.path.join(
        os.path.dirname(scrolls_metric_path),
        os.path.basename(scrolls_metric_path).replace(".", "_") + ".py",
    )
    shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
    return updated_scrolls_metric_path


def _process_doc_prepended_question(doc):
    # "When a query is given in addition to the raw text (as
    # in QMSum, Qasper, NarrativeQA, QuALITY, and ContractNLI),
    # we prepend it to the text, using two newlines as a natural separator"
    input = doc["input"]
    split = input.find("\n\n")
    return {
        "id": doc["id"],
        "pid": doc["pid"],
        "input": input,
        "outputs": doc["outputs"],
        "question": input[0:split],
        "text": input[split + 2 :],
    }


def _drop_duplicates_in_input(untokenized_dataset):
    """Collapse rows that share an ``id`` into a single row holding all answers.

    Adapted from scrolls/evaluator/dataset_evaluator.py.

    The first row seen for each id is kept, in original order. Its ``output``
    column is replaced by an ``outputs`` column listing every ``output`` value
    recorded for that id, in order of appearance.
    """
    slot_of_id = {}  # example id -> index into `grouped_outputs`
    rows_to_keep = []
    grouped_outputs = []
    id_and_output = zip(untokenized_dataset["id"], untokenized_dataset["output"])
    for row_idx, (example_id, answer) in enumerate(id_and_output):
        slot = slot_of_id.get(example_id)
        if slot is not None:
            # Repeated id: attach this answer to the row already kept.
            grouped_outputs[slot].append(answer)
            continue
        slot_of_id[example_id] = len(grouped_outputs)
        rows_to_keep.append(row_idx)
        grouped_outputs.append([answer])

    deduplicated = untokenized_dataset.select(rows_to_keep).flatten_indices()
    deduplicated = deduplicated.remove_columns("output")
    return deduplicated.add_column("outputs", grouped_outputs)


def _num_cpu_cores():
    # https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python/55423170#55423170
    try:
        import psutil

        return psutil.cpu_count(logical=False)
    except ImportError:
        import os

        return len(os.sched_getaffinity(0))


class _SCROLLSTask(ConfigurableTask):
    """Base class for SCROLLS long-context tasks.

    Subclasses set ``DATASET_NAME`` to a ``tau/scrolls`` configuration and
    implement:

    * ``_process_doc``: returns a list containing the processed doc. Every
      SCROLLS task returns exactly one.
    * ``_scrolls_metrics``: maps lm-eval metric names to keys of the dict
      returned by the SCROLLS metric's ``compute``.

    When ``PRUNE_TOKENIZERS`` is set, ``download`` also drops every sample
    whose prompt is longer than ``PRUNE_MAX_TOKENS`` tokens under any of
    those tokenizers.
    """

    VERSION = 2
    DATASET_PATH = "tau/scrolls"
    DATASET_NAME = None
    PRUNE_TOKENIZERS = None
    PRUNE_MAX_TOKENS = None
    PRUNE_NUM_PROC = None

    def __init__(self, config=None):
        # NOTE(review): `config` is accepted for interface compatibility but
        # ignored; only the task version is forwarded as metadata.
        super().__init__(config={"metadata": {"version": self.VERSION}})
        if self.DATASET_NAME is not None:
            # The SCROLLS metric script, configured for this dataset.
            self.metric = load(_download_metric(), config_name=self.DATASET_NAME)

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def _process_split(self, split):
        """Apply ``_process_doc`` to every row of ``self.dataset[split]``.

        ``_process_doc`` returns a list, but ``datasets.Dataset.map`` needs the
        mapped function to return a single mapping per row, so the one-element
        list is unwrapped here.
        """

        def _process_one(doc):
            processed = self._process_doc(doc)
            if len(processed) != 1:
                raise ValueError(
                    f"{type(self).__name__}._process_doc must return exactly one "
                    f"document per sample, got {len(processed)}"
                )
            return processed[0]

        return self.dataset[split].map(_process_one)

    def training_docs(self):
        return self._process_split("train")

    def validation_docs(self):
        return self._process_split("validation")

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["input"]

    def download(self, *args, **kwargs):
        """Load the train/validation splits, merge duplicate ids, optionally prune."""
        # `load_dataset` has no `splits` parameter (unknown keyword arguments
        # are forwarded to the builder config), so load each split explicitly.
        # The test split is not loaded because `has_test_docs` is False.
        self.dataset: datasets.DatasetDict = datasets.DatasetDict(
            {
                split: datasets.load_dataset(
                    self.DATASET_PATH, self.DATASET_NAME, split=split
                )
                for split in ("train", "validation")
            }
        )
        for split in self.dataset:
            self.dataset[split] = _drop_duplicates_in_input(self.dataset[split])
        if self.PRUNE_TOKENIZERS is not None:
            self.prune()

    def _get_prune_text(self, sample):
        """Return the text whose token length decides whether `sample` is pruned."""
        return self.doc_to_text(self._process_doc(sample)[0])

    def prune(self, **kwargs):
        """Keep only samples whose prompt is at most ``max_length`` tokens long
        under every tokenizer.

        Tokenizers come from ``kwargs["tokenizer"]`` (falling back to
        ``kwargs["pretrained"]``) plus ``PRUNE_TOKENIZERS``.

        The limit is ``PRUNE_MAX_TOKENS``, falling back to
        ``kwargs["max_length"]``. When neither is set, nothing is removed.
        """
        tokenizer_names = [kwargs.get("tokenizer", kwargs.get("pretrained"))]
        if self.PRUNE_TOKENIZERS is not None:
            tokenizer_names.extend(self.PRUNE_TOKENIZERS)
        # `download` calls prune() without kwargs, so the model tokenizer may be
        # missing; skip it rather than calling from_pretrained(None).
        tokenizer_names = [name for name in tokenizer_names if name is not None]
        max_length = self.PRUNE_MAX_TOKENS or kwargs.get("max_length")
        tokenizers = [AutoTokenizer.from_pretrained(name) for name in tokenizer_names]

        # Verdicts memoised by prompt text: several samples can share a text.
        cache = {}

        def _filter(sample):
            text = self._get_prune_text(sample)
            keep = cache.get(text)
            if keep is None:
                keep = max_length is None or all(
                    len(tokenizer(text).input_ids) <= max_length
                    for tokenizer in tokenizers
                )
                cache[text] = keep
            return keep

        self.dataset = self.dataset.filter(_filter, num_proc=self.PRUNE_NUM_PROC)

    def doc_to_target(self, doc):
        return " " + ", ".join(doc["outputs"])

    def doc_to_text(self, doc):
        return f"{doc['text']}\n\nQuestion: {doc['question']}\nAnswer:"

    def higher_is_better(self):
        return {x: True for x in self._scrolls_metrics()}

    @abstractmethod
    def _scrolls_metrics(self):
        """Map lm-eval metric names to keys of the SCROLLS metric output."""
        pass

    def _make_compute_metrics(self, value):
        """Build an aggregation fn that scores all samples with the SCROLLS metric.

        Each sample is a ``(prediction, references)`` pair; the returned function
        extracts ``computed[value]`` from the metric's ``compute`` output.
        """

        def compute_metrics(samples):
            predictions, references = zip(*samples)  # unzip, if you will
            computed = self.metric.compute(
                predictions=predictions, references=references
            )
            return computed[value]

        return compute_metrics

    def aggregation(self):
        return {
            key: self._make_compute_metrics(value)
            for key, value in self._scrolls_metrics().items()
        }


class _SCROLLSMultipleChoiceTask(_SCROLLSTask):
    """Base class for SCROLLS tasks scored by ranking a fixed set of answer choices.

    ``_process_doc`` in subclasses must set:

    * ``doc["choices"]``: the candidate answer strings.
    * ``doc["gold"]``: the index of the correct choice.

    Each choice is scored with a loglikelihood request, and per-sample scores
    are averaged with ``mean`` rather than the SCROLLS metric script.
    """

    def __post_init__(self):
        # NOTE(review): __post_init__ is only invoked automatically by
        # dataclass-generated __init__; _SCROLLSTask.__init__ does not call it,
        # so unless ConfigurableTask does, this never runs — confirm.
        self.metric = None

    def _scrolls_metrics(self):
        # Not consulted: aggregation() and higher_is_better() are overridden.
        return None

    def aggregation(self):
        return {"em": mean, "acc": mean, "acc_norm": mean}

    def higher_is_better(self):
        return {"em": True, "acc": True, "acc_norm": True}

    def process_results(self, doc, results):
        gold = doc["gold"]

        # One result per choice, in request order; presumably each is a
        # (loglikelihood, is_greedy) pair — only the loglikelihood is used.
        lls, _ = zip(*results)
        acc = 1.0 if np.argmax(lls) == gold else 0.0
        # Length-normalised accuracy: divide each loglikelihood by the
        # character length of its choice.
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(np.asarray(lls) / completion_len) == gold else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
            # `em` is acc_norm expressed on a 0-100 scale.
            "em": acc_norm * 100.0,
        }

    def construct_requests(self, doc, ctx, **kwargs):
        """Build one loglikelihood request per answer choice."""
        apply_chat_template = kwargs.pop("apply_chat_template", False)
        # Without a chat template the choice is appended to the raw prompt
        # (which ends in e.g. "Answer:"), so it gets a leading space.
        prefix = "" if apply_chat_template else " "
        request_list = [
            Instance(
                request_type="loglikelihood",
                doc=doc,
                arguments=(ctx, f"{prefix}{choice}"),
                idx=i,
                **kwargs,
            )
            for i, choice in enumerate(doc["choices"])
        ]
        return request_list


class _SCROLLSSummaryTask(_SCROLLSTask):
    """Base class for SCROLLS summarization tasks, scored with ROUGE-1/2/L."""

    def _process_doc(self, doc):
        # Summarization inputs carry no prepended query; use the row as-is.
        return [doc]

    def _scrolls_metrics(self):
        return {
            "rouge1": "rouge/rouge1",
            "rouge2": "rouge/rouge2",
            "rougeL": "rouge/rougeL",
        }

    def process_results(self, doc, results):
        # Every ROUGE variant receives the same (prediction, references) pair;
        # the actual scores are computed at aggregation time.
        sample = (results[0], doc["outputs"])
        return {"rouge1": sample, "rouge2": sample, "rougeL": sample}

    def construct_requests(self, doc, ctx, **kwargs):
        """Build a single free-form generation request that stops at a newline."""
        # Not needed for generation; drop it so it is not forwarded to Instance.
        kwargs.pop("apply_chat_template", False)
        return Instance(
            request_type="generate_until",
            doc=doc,
            arguments=(ctx, {"until": ["\n"]}),
            idx=0,
            **kwargs,
        )

    def doc_to_text(self, doc):
        return f"{doc['input']}\n\nQuestion: What is a summary of the preceding text?\nAnswer:"


class Qasper(_SCROLLSTask):
    """A Dataset of Information-Seeking Questions and Answers Anchored in Research Papers
    https://arxiv.org/abs/2105.03011
    """

    DATASET_NAME = "qasper"

    def _process_doc(self, doc):
        doc = _process_doc_prepended_question(doc)
        # A question is yes/no only if every reference answer normalizes to
        # "yes" or "no" (vacuously true when there are no answers).
        doc["is_yes_no"] = all(
            squad_metrics.normalize_answer(answer) in ("yes", "no")
            for answer in doc["outputs"]
        )
        return [doc]

    def _scrolls_metrics(self):
        return {"f1": "f1"}

    def process_results(self, doc, results):
        if doc["is_yes_no"]:
            # results[0] / results[1] answer the " yes" / " no" loglikelihood
            # requests (idx 0 and 1). NOTE(review): if these are
            # (loglikelihood, is_greedy) tuples, they compare lexicographically.
            prediction = " yes" if results[0] > results[1] else " no"
        elif len(results[0].strip()) == 0:
            # Empty generation: map to Qasper's "Unanswerable" answer.
            prediction = "Unanswerable"
        else:
            prediction = results[0]
        return {"f1": (prediction, doc["outputs"])}

    def construct_requests(self, doc, ctx, **kwargs):
        """Yes/no questions: two loglikelihood requests; otherwise one generation."""
        apply_chat_template = kwargs.pop("apply_chat_template", False)
        if doc["is_yes_no"]:
            # Leading space only when continuing the raw prompt ("...Answer:").
            prefix = "" if apply_chat_template else " "
            return [
                Instance(
                    request_type="loglikelihood",
                    doc=doc,
                    arguments=(ctx, f"{prefix}{answer}"),
                    idx=idx,
                    **kwargs,
                )
                for idx, answer in enumerate(("yes", "no"))
            ]
        return Instance(
            request_type="generate_until",
            doc=doc,
            arguments=(ctx, {"until": ["\n"]}),
            idx=0,
            **kwargs,
        )


class QuALITY(_SCROLLSMultipleChoiceTask):
    """QuALITY: Question Answering with Long Input Texts, Yes!
    https://arxiv.org/abs/2112.08608
    """

    DATASET_NAME = "quality"
    _multiple_choice_pattern = re.compile(r" *\([A-D]\) *")

    @staticmethod
    def _normalize_answer(text):
        """Collapse every run of whitespace in `text` to a single space."""
        words = text.split()
        return " ".join(words).strip()

    def _process_doc(self, doc):
        """Split the "(A) … (D) …" option list off the passage and locate the gold answer."""
        doc = _process_doc_prepended_question(doc)
        body = doc["text"]

        # The option list runs up to the first blank line after "(D)".
        last_option_at = body.find("(D)")
        options_end = body.find("\n\n", last_option_at)
        options_block = body[:options_end]
        doc["text"] = body[options_end:].strip()

        # Splitting on the "(A)".."(D)" markers leaves a leading prefix to drop.
        _, *raw_choices = QuALITY._multiple_choice_pattern.split(options_block)
        choices = [QuALITY._normalize_answer(raw) for raw in raw_choices]
        doc["choices"] = choices
        doc["gold"] = choices.index(QuALITY._normalize_answer(doc["outputs"][0]))
        return [doc]


class NarrativeQA(_SCROLLSTask):
    """The NarrativeQA Reading Comprehension Challenge
    https://arxiv.org/abs/1712.07040
    """

    DATASET_NAME = "narrative_qa"

    def _process_doc(self, doc):
        return [_process_doc_prepended_question(doc)]

    def _scrolls_metrics(self):
        return {"f1": "f1"}

    def _get_prune_text(self, doc):
        # pruning narrativeqa takes forever -- let's cheat a bit
        # and just cache on the text, not the question, since
        # the dataset is different questions about the same large
        # documents
        return self._process_doc(doc)[0]["text"]

    def process_results(self, doc, results):
        return {"f1": (results[0], doc["outputs"])}

    def construct_requests(self, doc, ctx, **kwargs):
        """Build a single free-form generation request that stops at a newline."""
        # Not needed for generation; drop it so it is not forwarded to Instance.
        kwargs.pop("apply_chat_template", False)
        return Instance(
            request_type="generate_until",
            doc=doc,
            arguments=(ctx, {"until": ["\n"]}),
            idx=0,
            **kwargs,
        )


class ContractNLI(_SCROLLSMultipleChoiceTask):
    """ContractNLI: A Dataset for Document-level Natural Language Inference for Contracts
    https://arxiv.org/abs/2110.01799
    """

    DATASET_NAME = "contract_nli"
    CHOICES = ["Not mentioned", "Entailment", "Contradiction"]

    def _process_doc(self, doc):
        doc = _process_doc_prepended_question(doc)
        # Give each doc its own list so per-doc edits cannot leak into the
        # shared class attribute.
        doc["choices"] = list(ContractNLI.CHOICES)
        doc["gold"] = ContractNLI.CHOICES.index(doc["outputs"][0])
        return [doc]

    def doc_to_text(self, doc):
        # The prepended "question" is the hypothesis to classify.
        return f"{doc['text']}\n\nHypothesis: {doc['question']}\nConclusion:"


class GovReport(_SCROLLSSummaryTask):
    """Efficient Attentions for Long Document Summarization
    https://arxiv.org/abs/2104.02112

    Note: The average length of the reference summaries is ~3,000
    characters, or ~600 tokens as tokenized by GPT-NeoX. For causal models,
    it is recommended to set `max_gen_toks` sufficiently large (e.g. 1024)
    to allow a full summary to be generated.
    """

    DATASET_NAME = "gov_report"


class SummScreenFD(_SCROLLSSummaryTask):
    """SummScreen: A Dataset for Abstractive Screenplay Summarization
    https://arxiv.org/abs/2104.07091

    Uses the ``summ_screen_fd`` configuration of ``tau/scrolls``. Prompting,
    request construction and ROUGE scoring are inherited unchanged from the
    summary-task base class.
    """

    DATASET_NAME = "summ_screen_fd"


class QMSum(_SCROLLSSummaryTask):
    """QMSum: A New Benchmark for Query-based Multi-domain
    Meeting Summarization

    https://arxiv.org/abs/2104.05938
    """

    DATASET_NAME = "qmsum"

    def _process_doc(self, doc):
        # QMSum is query-based: split the prepended query off the transcript.
        processed = _process_doc_prepended_question(doc)
        return [processed]

    def doc_to_text(self, doc):
        # Ask the dataset's own query instead of the generic summary prompt.
        passage = doc["text"]
        query = doc["question"]
        return f"{passage}\n\nQuestion: {query}\nAnswer:"