Commit 2009ec4b authored by Baber

update `scrolls`

parent 4ad6cd9f
@@ -2,9 +2,9 @@ import re
 from abc import abstractmethod
 from functools import reduce
+import datasets
 import numpy as np
 import transformers.data.metrics.squad_metrics as squad_metrics
-from datasets import Dataset
 from evaluate import load
 from transformers import AutoTokenizer
@@ -135,26 +135,10 @@ class _SCROLLSTask(ConfigurableTask):
         return False

     def training_docs(self):
-        processed_docs = list(map(self._process_doc, self.dataset["train"]))
-        # Flatten the list of lists since _process_doc returns a list of one element.
-        processed_docs = [item for sublist in processed_docs for item in sublist]
-        processed_dict = {
-            key: [d[key] for d in processed_docs] for key in processed_docs[0]
-        }
-        return Dataset.from_dict(processed_dict)
+        # _process_doc returns a one-element list, so unwrap it here:
+        # Dataset.map expects the mapped function to return a dict, not a list.
+        return self.dataset["train"].map(lambda doc: self._process_doc(doc)[0])

     def validation_docs(self):
-        processed_docs = list(map(self._process_doc, self.dataset["validation"]))
-        # Flatten the list of lists since _process_doc returns a list of one element.
-        processed_docs = [item for sublist in processed_docs for item in sublist]
-        processed_dict = {
-            key: [d[key] for d in processed_docs] for key in processed_docs[0]
-        }
-        return Dataset.from_dict(processed_dict)
+        return self.dataset["validation"].map(lambda doc: self._process_doc(doc)[0])

     def should_decontaminate(self):
         return True
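
Context for the change above: `datasets.Dataset.map` applies a dict-to-dict function row by row and returns another `Dataset`, which is what the deleted flatten-and-`Dataset.from_dict` round trip was doing by hand. A minimal sketch of the pattern (column names are hypothetical, not taken from the commit):

from datasets import Dataset

ds = Dataset.from_dict({"input": ["a b c", "d e f"]})

def process(doc):
    # map() must get a dict back; returning a one-element list raises TypeError
    return {**doc, "question": doc["input"].split()[0]}

print(ds.map(process)[0])  # {'input': 'a b c', 'question': 'a'}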
@@ -163,8 +147,9 @@ class _SCROLLSTask(ConfigurableTask):
         return doc["input"]

     def download(self, *args, **kwargs):
-        super().download(*args, **kwargs)
-        del self.dataset["test"]
+        # Keep only the labeled splits (the SCROLLS test split has no answers);
+        # note load_dataset takes `split`, not `splits`.
+        self.dataset = datasets.DatasetDict({
+            s: datasets.load_dataset(self.DATASET_PATH, self.DATASET_NAME, split=s)
+            for s in ("train", "validation")
+        })
         for split in self.dataset:
             self.dataset[split] = _drop_duplicates_in_input(self.dataset[split])
         if self.PRUNE_TOKENIZERS is not None:
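
The per-split loading in `download` follows the standard `datasets` pattern; here is the same shape with a small public dataset standing in for SCROLLS (the dataset choice is illustrative only):

import datasets

# Build a DatasetDict from explicitly named splits, skipping unlabeled ones.
dd = datasets.DatasetDict({
    split: datasets.load_dataset("glue", "mrpc", split=split)
    for split in ("train", "validation")
})
print(list(dd))  # ['train', 'validation']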
@@ -173,23 +158,26 @@ class _SCROLLSTask(ConfigurableTask):
     def _get_prune_text(self, sample):
         return self.doc_to_text(self._process_doc(sample)[0])

-    def prune(self):
+    def prune(self, **kwargs):
         """Create a pruned version of a SCROLLS task dataset containing only inputs
         that are shorter than `max_length` tokens when tokenized by each tokenizer
         """
-        tokenizers = [
-            AutoTokenizer.from_pretrained(tokenizer)
-            for tokenizer in self.PRUNE_TOKENIZERS
-        ]
+        # Prune with the evaluated model's tokenizer (passed via kwargs) plus any
+        # class-level PRUNE_TOKENIZERS; skip the kwarg when neither `tokenizer`
+        # nor `pretrained` was supplied, so from_pretrained(None) is never called.
+        tok = kwargs.get("tokenizer", kwargs.get("pretrained"))
+        toks = [tok] if tok is not None else []
+        if self.PRUNE_TOKENIZERS is not None:
+            toks.extend(self.PRUNE_TOKENIZERS)
+        max_length = self.PRUNE_MAX_TOKENS or kwargs.get("max_length")
+        tokenizers = [AutoTokenizer.from_pretrained(tokenizer) for tokenizer in toks]
         cache = {}

         def _filter(sample):
             text = self._get_prune_text(sample)
-            cached = cache.get(text, None)
+            cached = cache.get(text)
             if cached is None:
                 for tokenizer in tokenizers:
-                    if len(tokenizer(text).input_ids) > self.PRUNE_MAX_TOKENS:
+                    if (
+                        max_length is not None
+                        and len(tokenizer(text).input_ids) > max_length
+                    ):
                         cache[text] = False
                         return False
             cache[text] = True
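
With the new signature, callers can forward model arguments so pruning uses the evaluated model's own tokenizer and context length even when the class-level `PRUNE_*` attributes are unset. A hedged usage sketch; `task` stands for an instantiated SCROLLS task, and the argument values are placeholders, not taken from the commit:

# `pretrained` / `max_length` mirror typical model_args.
task.prune(pretrained="gpt2", max_length=1024)

# An explicit `tokenizer` kwarg takes precedence over `pretrained`:
task.prune(tokenizer="EleutherAI/pythia-70m", max_length=2048)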
@@ -206,7 +194,7 @@ class _SCROLLSTask(ConfigurableTask):
         return f"{doc['text']}\n\nQuestion: {doc['question']}\nAnswer:"

     def higher_is_better(self):
-        return {x: True for x in self._scrolls_metrics().keys()}
+        return {x: True for x in self._scrolls_metrics()}

     @abstractmethod
     def _scrolls_metrics(self):
@@ -263,9 +251,9 @@ class _SCROLLSMultipleChoiceTask(_SCROLLSTask):
             Instance(
                 request_type="loglikelihood",
                 doc=doc,
-                arguments=(ctx, " {}".format(choice))
+                arguments=(ctx, f" {choice}")
                 if not apply_chat_template
-                else (ctx, "{}".format(choice)),
+                else (ctx, f"{choice}"),
                 idx=i,
                 **kwargs,
             )
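
The f-string conversion is cosmetic, but the branch it touches encodes a real tokenizer subtlety: outside chat templates, the continuation keeps a leading space so it tokenizes the way it would mid-sentence, while chat templates insert their own separators. A quick check (model choice arbitrary):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
# BPE folds the leading space into the first token, so the two encodings differ.
print(tok(" Paris").input_ids == tok("Paris").input_ids)  # False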