Commit 2009ec4b authored by Baber

update `scrolls`

parent 4ad6cd9f
@@ -2,9 +2,9 @@ import re
 from abc import abstractmethod
 from functools import reduce
 
+import datasets
 import numpy as np
 import transformers.data.metrics.squad_metrics as squad_metrics
-from datasets import Dataset
 from evaluate import load
 from transformers import AutoTokenizer
@@ -135,26 +135,10 @@ class _SCROLLSTask(ConfigurableTask):
         return False
 
     def training_docs(self):
-        processed_docs = list(map(self._process_doc, self.dataset["train"]))
-        # Flatten the list of lists since _process_doc returns a list of one element.
-        processed_docs = [item for sublist in processed_docs for item in sublist]
-        processed_dict = {
-            key: [d[key] for d in processed_docs] for key in processed_docs[0]
-        }
-        return Dataset.from_dict(processed_dict)
+        return self.dataset["train"].map(self._process_doc)
 
     def validation_docs(self):
-        processed_docs = list(map(self._process_doc, self.dataset["validation"]))
-        # Flatten the list of lists since _process_doc returns a list of one element.
-        processed_docs = [item for sublist in processed_docs for item in sublist]
-        processed_dict = {
-            key: [d[key] for d in processed_docs] for key in processed_docs[0]
-        }
-        return Dataset.from_dict(processed_dict)
+        return self.dataset["validation"].map(self._process_doc)
 
     def should_decontaminate(self):
         return True
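Note on the `training_docs` / `validation_docs` change: `datasets.Dataset.map` applies a per-example function and returns a new `Dataset`, so the manual flatten-and-`Dataset.from_dict` rebuild is no longer needed. A minimal sketch of the pattern, with a hypothetical stand-in for `_process_doc` (when not batched, `map` expects the function to return a dict per example):

    import datasets

    def _process_doc(doc):
        # hypothetical per-example transform; map() expects a dict back
        return {"text": doc["text"].strip()}

    ds = datasets.Dataset.from_dict({"text": [" a ", " b "]})
    processed = ds.map(_process_doc)  # new Dataset with the transform applied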
@@ -163,8 +147,9 @@ class _SCROLLSTask(ConfigurableTask):
         return doc["input"]
 
     def download(self, *args, **kwargs):
-        super().download(*args, **kwargs)
-        del self.dataset["test"]
+        self.dataset: datasets.DatasetDict = datasets.load_dataset(
+            self.DATASET_PATH, self.DATASET_NAME, splits=["train", "validation"]
+        )
         for split in self.dataset:
             self.dataset[split] = _drop_duplicates_in_input(self.dataset[split])
         if self.PRUNE_TOKENIZERS is not None:
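For reference, a minimal sketch of building a `DatasetDict` that holds only the train and validation splits with the standard `datasets` API, where split selection goes through the `split` keyword ("path" and "name" are placeholders for the task's `DATASET_PATH` / `DATASET_NAME`):

    import datasets

    # placeholders: substitute the task's actual dataset path and config name
    dataset = datasets.DatasetDict(
        {s: datasets.load_dataset("path", "name", split=s) for s in ("train", "validation")}
    )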
@@ -173,23 +158,26 @@ class _SCROLLSTask(ConfigurableTask):
     def _get_prune_text(self, sample):
         return self.doc_to_text(self._process_doc(sample)[0])
 
-    def prune(self):
+    def prune(self, **kwargs):
         """Create a pruned version of a SCROLLS task dataset containing only inputs
         that are less than `max_tokens` when tokenized by each tokenizer
         """
-        tokenizers = [
-            AutoTokenizer.from_pretrained(tokenizer)
-            for tokenizer in self.PRUNE_TOKENIZERS
-        ]
+        toks = [kwargs.get("tokenizer", kwargs.get("pretrained"))]
+        if self.PRUNE_TOKENIZERS is not None:
+            toks.extend(self.PRUNE_TOKENIZERS)
+        max_length = self.PRUNE_MAX_TOKENS or kwargs.get("max_length")
+        tokenizers = [AutoTokenizer.from_pretrained(tokenizer) for tokenizer in toks]
         cache = {}
 
         def _filter(sample):
             text = self._get_prune_text(sample)
-            cached = cache.get(text, None)
+            cached = cache.get(text)
             if cached is None:
                 for tokenizer in tokenizers:
-                    if len(tokenizer(text).input_ids) > self.PRUNE_MAX_TOKENS:
+                    if (
+                        max_length is not None
+                        and len(tokenizer(text).input_ids) > max_length
+                    ):
                         cache[text] = False
                         return False
                 cache[text] = True
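The reworked `prune` now derives its tokenizer list from a `tokenizer`/`pretrained` kwarg (plus any class-level `PRUNE_TOKENIZERS`) and its length cutoff from `PRUNE_MAX_TOKENS` or a `max_length` kwarg. A rough sketch of a call, purely illustrative since the call site is not part of this diff:

    # hypothetical call site: `task` stands in for a SCROLLS task instance,
    # and the kwargs mirror the keys prune() reads above
    task.prune(pretrained="gpt2", max_length=1024)
    task.prune(tokenizer="gpt2", max_length=1024)  # equivalent, naming the tokenizer directly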
@@ -206,7 +194,7 @@ class _SCROLLSTask(ConfigurableTask):
         return f"{doc['text']}\n\nQuestion: {doc['question']}\nAnswer:"
 
     def higher_is_better(self):
-        return {x: True for x in self._scrolls_metrics().keys()}
+        return {x: True for x in self._scrolls_metrics()}
 
     @abstractmethod
     def _scrolls_metrics(self):
@@ -263,9 +251,9 @@ class _SCROLLSMultipleChoiceTask(_SCROLLSTask):
             Instance(
                 request_type="loglikelihood",
                 doc=doc,
-                arguments=(ctx, " {}".format(choice))
-                if not apply_chat_template
-                else (ctx, "{}".format(choice)),
+                arguments=(ctx, f" {choice}")
+                if not apply_chat_template
+                else (ctx, f"{choice}"),
                 idx=i,
                 **kwargs,
             )
...