Commit 764f6fb2 authored by Baber

better api

parent a61b3ee6
@@ -60,7 +60,7 @@ class TaskConfig(dict):
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
-download_dataset: Optional[bool] = None
+download_dataset: Optional[Callable] = None
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
@@ -822,13 +822,10 @@ class ConfigurableTask(Task):
self.download(self.config.dataset_kwargs)
else:
self.dataset = self.config.download_dataset(
-metadata=self.config.metadata.get(
-    "tokenizer",
-    self.config.metadata.get("pretrained"),
-    **self.config.dataset_kwargs
-    if self.config.dataset_kwargs is not None
-    else {},
-)
+metadata=self.config.metadata,
+**self.config.dataset_kwargs
+if self.config.dataset_kwargs is not None
+else {},
)
self._training_docs = None
self._fewshot_docs = None
......
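For orientation, the change above means download_dataset is no longer a bool flag but a callable stored on the task config: ConfigurableTask now invokes it as download_dataset(metadata=self.config.metadata, **dataset_kwargs) and assigns the result to self.dataset. A minimal sketch of such a callable follows; the my_download name and its contents are illustrative only, not part of this commit.

```python
import datasets


def my_download(metadata: dict, **dataset_kwargs) -> dict[str, datasets.Dataset]:
    # Under the new API the callable receives the whole metadata dict and decides
    # itself which key to read (e.g. "tokenizer", falling back to "pretrained"),
    # instead of ConfigurableTask unpacking those keys on its behalf.
    model_name = metadata.get("tokenizer", metadata.get("pretrained"))
    rows = [{"input": f"example built for {model_name}", "outputs": ["42"]}]
    return {"test": datasets.Dataset.from_list(rows, split=datasets.Split.TEST)}
```

In a task YAML such a callable would typically be hooked up with a !function entry for download_dataset, with metadata supplied by the harness at run time.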
@@ -193,7 +193,7 @@ def generate_input_output(
query=query,
)
-return input_text, answers
+return input_text, answers, query
def generate_samples(
@@ -213,7 +213,7 @@ def generate_samples(
remove_newline_tab: bool = False,
random_seed: int = 42,
TOKENIZER=None,
-):
+) -> list[dict]:
assert TOKENIZER is not None, "TOKENIZER is not defined."
num_needle_k = max(num_needle_k, num_needle_q)
write_jsons = []
@@ -233,7 +233,7 @@ def generate_samples(
total_tokens = 0 # Track the total tokens generated for the first example
while total_tokens + tokens_to_generate < max_seq_length:
-input_text, answer = generate_input_output(
+input_text, answer, query = generate_input_output(
num_haystack,
haystack,
type_haystack=type_haystack,
@@ -247,9 +247,6 @@ def generate_samples(
)
# Calculate the number of tokens in the example
total_tokens = len(TOKENIZER(input_text + " ".join(answer)).input_ids)
-# print(
-# f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Haystack: {num_haystack}"
-# )
if total_tokens + tokens_to_generate > max_seq_length:
num_haystack -= incremental
break
@@ -270,7 +267,7 @@ def generate_samples(
used_haystack = num_haystack
while True:
try:
-input_text, answer = generate_input_output(
+input_text, answer, query = generate_input_output(
used_haystack,
haystack,
type_haystack=type_haystack,
@@ -301,6 +298,7 @@ def generate_samples(
"outputs": answer,
"length": length,
"max_length": max_seq_length,
"gen_prefix": f"The special magic {type_needle_v} for {query} mentioned in the provided text are",
}
if formatted_output["outputs"][0] not in formatted_output["input"]:
assert (
......
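With generate_input_output now also returning the query, generate_samples can render a per-sample gen_prefix alongside the fields it already wrote. Schematically, one generated RULER record now looks like the dict below; every value is invented for illustration, only the field names follow the diff.

```python
# Illustrative shape of one record appended to write_jsons (values are made up).
record = {
    "input": "<haystack text with the needle(s) hidden somewhere inside>",
    "outputs": ["7510343"],
    "length": 4096,
    "max_length": 4096,
    "gen_prefix": "The special magic numbers for brave-tiger mentioned in the provided text are",
}
```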
@@ -2,7 +2,7 @@
import itertools
import re
from functools import cache
-from typing import Literal
+from typing import Literal, Generator, Union, TYPE_CHECKING
import datasets
from transformers import AutoTokenizer
@@ -10,9 +10,16 @@ from transformers import AutoTokenizer
from lm_eval.tasks.ruler.essays import get_all_essays
from lm_eval.tasks.ruler.prepare import generate_samples
+if TYPE_CHECKING:
+    import transformers
@cache
-def get_tokenizer(pretrained):
+def get_tokenizer(
+    **kwargs,
+) -> Union["transformers.PreTrainedTokenizer", "transformers.PreTrainedTokenizerFast"]:
+    kwargs = kwargs.get("metadata", {})
+    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
+    assert pretrained, "No tokenizer or pretrained provided."
+    print("using tokenizer ", pretrained)
return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
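The lookup order applied by the reworked get_tokenizer is easiest to see in isolation. The standalone sketch below mirrors it; resolve_pretrained and the model names are hypothetical and exist only to show that an explicit tokenizer entry takes precedence over pretrained.

```python
def resolve_pretrained(metadata: dict) -> str:
    # Mirrors get_tokenizer's resolution: prefer "tokenizer", fall back to "pretrained".
    name = metadata.get("tokenizer", metadata.get("pretrained", ""))
    assert name, "No tokenizer or pretrained provided."
    return name


assert resolve_pretrained({"pretrained": "gpt2"}) == "gpt2"
assert resolve_pretrained({"pretrained": "gpt2", "tokenizer": "gpt2-large"}) == "gpt2-large"
```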
@@ -36,7 +43,9 @@ RANDOM_SEED = 42
@cache
-def get_haystack(type_haystack: Literal["essay", "repeat", "needle"]):
+def get_haystack(
+    type_haystack: Literal["essay", "repeat", "needle"],
+) -> Union[list[str], str]:
NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
if type_haystack == "essay":
essay = get_all_essays()["text"]
@@ -51,7 +60,7 @@ def get_haystack(type_haystack: Literal["essay", "repeat", "needle"]):
return haystack
-def flatten(df):
+def flatten(df: Generator) -> dict[str, datasets.Dataset]:
return {
"test": datasets.Dataset.from_list(
list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
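Concretely, flatten merges the per-sequence-length sample lists produced by generate_samples into a single TEST split. A toy rendering of that transformation, with made-up records standing in for the real samples:

```python
import itertools

import datasets

per_length = (
    [{"input": "short haystack", "outputs": ["1"]}],
    [{"input": "long haystack", "outputs": ["2"]}],
)
merged = datasets.Dataset.from_list(
    list(itertools.chain.from_iterable(per_length)), split=datasets.Split.TEST
)
# flatten(per_length) would return {"test": merged}.
```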
@@ -60,7 +69,7 @@ def flatten(df):
# ruff: noqa
-niah_single_1 = lambda x: flatten(
+niah_single_1 = lambda **kwargs: flatten(
generate_samples(
get_haystack(type_haystack="repeat"),
max_seq_length=seq,
@@ -68,7 +77,7 @@ niah_single_1 = lambda x: flatten(
type_haystack="repeat",
type_needle_k="words",
type_needle_v="numbers",
-TOKENIZER=get_tokenizer(x),
+TOKENIZER=get_tokenizer(**kwargs),
)
for seq in SEQ_LENGTHS
)
@@ -86,7 +95,7 @@ niah_single_2 = lambda x: flatten(
for seq in SEQ_LENGTHS
)
# noqa
-niah_single_3 = lambda x: flatten(
+niah_single_3 = lambda **kwargs: flatten(
generate_samples(
get_haystack(type_haystack="essay"),
max_seq_length=seq,
@@ -94,12 +103,12 @@ niah_single_3 = lambda x: flatten(
type_haystack="essay",
type_needle_k="words",
type_needle_v="uuids",
-TOKENIZER=get_tokenizer(x),
+TOKENIZER=get_tokenizer(**kwargs),
)
for seq in SEQ_LENGTHS
)
# noqa
-niah_multikey_1 = lambda x: flatten(
+niah_multikey_1 = lambda **kwargs: flatten(
generate_samples(
get_haystack(type_haystack="essay"),
max_seq_length=seq,
@@ -108,12 +117,12 @@ niah_multikey_1 = lambda x: flatten(
type_needle_k="words",
type_needle_v="numbers",
num_needle_k=4,
-TOKENIZER=get_tokenizer(x),
+TOKENIZER=get_tokenizer(**kwargs),
)
for seq in SEQ_LENGTHS
)
# noqa
-niah_multikey_2 = lambda x: flatten(
+niah_multikey_2 = lambda **kwargs: flatten(
generate_samples(
get_haystack(type_haystack="needle"),
max_seq_length=seq,
@@ -121,12 +130,12 @@ niah_multikey_2 = lambda x: flatten(
type_haystack="needle",
type_needle_k="words",
type_needle_v="numbers",
-TOKENIZER=get_tokenizer(x),
+TOKENIZER=get_tokenizer(**kwargs),
)
for seq in SEQ_LENGTHS
)
# noqa
-niah_multikey_3 = lambda x: flatten(
+niah_multikey_3 = lambda **kwargs: flatten(
generate_samples(
get_haystack(type_haystack="needle"),
max_seq_length=seq,
@@ -134,12 +143,12 @@ niah_multikey_3 = lambda x: flatten(
type_haystack="needle",
type_needle_k="uuids",
type_needle_v="uuids",
-TOKENIZER=get_tokenizer(x),
+TOKENIZER=get_tokenizer(**kwargs),
)
for seq in SEQ_LENGTHS
)
# noqa
-niah_multivalue = lambda x: flatten(
+niah_multivalue = lambda **kwargs: flatten(
generate_samples(
get_haystack(type_haystack="essay"),
max_seq_length=seq,
@@ -148,12 +157,12 @@ niah_multivalue = lambda x: flatten(
type_needle_k="words",
type_needle_v="numbers",
num_needle_v=4,
-TOKENIZER=get_tokenizer(x),
+TOKENIZER=get_tokenizer(**kwargs),
)
for seq in SEQ_LENGTHS
)
# noqa
-niah_multiquery = lambda x: flatten(
+niah_multiquery = lambda **kwargs: flatten(
generate_samples(
get_haystack(type_haystack="essay"),
max_seq_length=seq,
@@ -162,7 +171,7 @@ niah_multiquery = lambda x: flatten(
type_needle_k="words",
type_needle_v="numbers",
num_needle_q=4,
-TOKENIZER=get_tokenizer(x),
+TOKENIZER=get_tokenizer(**kwargs),
)
for seq in SEQ_LENGTHS
)
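Tying this back to the TaskConfig change at the top of the commit: each niah_* lambda is exactly the kind of callable the new download_dataset field is meant to hold, and ConfigurableTask calls it with the run metadata. The wiring below is a hypothetical sketch, not code from this commit; the metadata value is a placeholder, and actually building the dataset downloads the named tokenizer and generates samples for every length in SEQ_LENGTHS.

```python
from lm_eval.api.task import TaskConfig

cfg = TaskConfig(
    task="niah_single_1",
    download_dataset=niah_single_1,  # the lambda defined above; a callable now, no longer a bool flag
    metadata={"pretrained": "EleutherAI/pythia-70m"},  # placeholder model name
)
# ConfigurableTask then does, in effect:
#   self.dataset = cfg.download_dataset(metadata=cfg.metadata, **(cfg.dataset_kwargs or {}))
```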
......
@@ -24,8 +24,7 @@ import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer
-from lm_eval.tasks.ruler.prepare import SEQ_LENGTHS
+from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
TASKS = {
"variable_tracking": {
......