Commit 764f6fb2 authored by Baber

better api

parent a61b3ee6
@@ -60,7 +60,7 @@ class TaskConfig(dict):
     # HF dataset options.
     # which dataset to use,
     # and what splits for what purpose
-    download_dataset: Optional[bool] = None
+    download_dataset: Optional[Callable] = None
    dataset_path: Optional[str] = None
    dataset_name: Optional[str] = None
    dataset_kwargs: Optional[dict] = None
@@ -822,13 +822,10 @@ class ConfigurableTask(Task):
             self.download(self.config.dataset_kwargs)
         else:
             self.dataset = self.config.download_dataset(
-                metadata=self.config.metadata.get(
-                    "tokenizer",
-                    self.config.metadata.get("pretrained"),
+                metadata=self.config.metadata,
                 **self.config.dataset_kwargs
                 if self.config.dataset_kwargs is not None
                 else {},
-            )
             )
         self._training_docs = None
         self._fewshot_docs = None
...
@@ -193,7 +193,7 @@ def generate_input_output(
         query=query,
     )
-    return input_text, answers
+    return input_text, answers, query
@@ -213,7 +213,7 @@ def generate_samples(
     remove_newline_tab: bool = False,
     random_seed: int = 42,
     TOKENIZER=None,
-):
+) -> list[dict]:
     assert TOKENIZER is not None, "TOKENIZER is not defined."
     num_needle_k = max(num_needle_k, num_needle_q)
     write_jsons = []
@@ -233,7 +233,7 @@ def generate_samples(
     total_tokens = 0  # Track the total tokens generated for the first example
     while total_tokens + tokens_to_generate < max_seq_length:
-        input_text, answer = generate_input_output(
+        input_text, answer, query = generate_input_output(
             num_haystack,
             haystack,
             type_haystack=type_haystack,
@@ -247,9 +247,6 @@ def generate_samples(
         )
         # Calculate the number of tokens in the example
         total_tokens = len(TOKENIZER(input_text + " ".join(answer)).input_ids)
-        # print(
-        #     f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Haystack: {num_haystack}"
-        # )
         if total_tokens + tokens_to_generate > max_seq_length:
             num_haystack -= incremental
             break
@@ -270,7 +267,7 @@ def generate_samples(
         used_haystack = num_haystack
         while True:
             try:
-                input_text, answer = generate_input_output(
+                input_text, answer, query = generate_input_output(
                     used_haystack,
                     haystack,
                     type_haystack=type_haystack,
@@ -301,6 +298,7 @@ def generate_samples(
             "outputs": answer,
             "length": length,
             "max_length": max_seq_length,
+            "gen_prefix": f"The special magic {type_needle_v} for {query} mentioned in the provided text are",
         }
         if formatted_output["outputs"][0] not in formatted_output["input"]:
             assert (
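generate_input_output now also returns the query, and each sample records a query-aware gen_prefix next to the existing fields. Roughly the shape of one emitted record, with illustrative values:

record = {
    "input": "...haystack text with needles...",
    "outputs": ["7564"],
    "length": 4096,
    "max_length": 4096,
    "gen_prefix": "The special magic numbers for alpha mentioned in the provided text are",
}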
...
@@ -2,7 +2,7 @@
 import itertools
 import re
 from functools import cache
-from typing import Literal
+from typing import Literal, Generator, Union, TYPE_CHECKING
 
 import datasets
 from transformers import AutoTokenizer
@@ -10,9 +10,16 @@ from transformers import AutoTokenizer
 from lm_eval.tasks.ruler.essays import get_all_essays
 from lm_eval.tasks.ruler.prepare import generate_samples
 
+if TYPE_CHECKING:
+    import transformers
+
 
 @cache
-def get_tokenizer(pretrained):
+def get_tokenizer(
+    **kwargs,
+) -> Union["transformers.PreTrainedTokenizer", "transformers.PreTrainedTokenizerFast"]:
+    kwargs = kwargs.get("metadata", {})
+    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
+    assert pretrained, "No tokenizer or pretrained provided."
     print("using tokenizer ", pretrained)
     return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
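get_tokenizer now resolves the tokenizer name from the task metadata rather than a positional argument, preferring an explicit tokenizer key over pretrained. A standalone sketch of that resolution order (the helper name is illustrative):

def resolve_pretrained(metadata: dict) -> str:
    # Mirrors get_tokenizer: an explicit tokenizer wins, then pretrained.
    return metadata.get("tokenizer", metadata.get("pretrained", {}))


assert resolve_pretrained({"tokenizer": "t", "pretrained": "p"}) == "t"
assert resolve_pretrained({"pretrained": "p"}) == "p"

One caveat worth noting: @cache hashes its call arguments, so calling get_tokenizer(metadata={...}) with a plain dict would raise TypeError: unhashable type: 'dict'; callers would need to pass a hashable mapping, or the decorator would have to go.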
@@ -36,7 +43,9 @@ RANDOM_SEED = 42
 
 @cache
-def get_haystack(type_haystack: Literal["essay", "repeat", "needle"]):
+def get_haystack(
+    type_haystack: Literal["essay", "repeat", "needle"],
+) -> Union[list[str], str]:
     NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
     if type_haystack == "essay":
         essay = get_all_essays()["text"]
@@ -51,7 +60,7 @@ def get_haystack(type_haystack: Literal["essay", "repeat", "needle"]):
     return haystack
 
 
-def flatten(df):
+def flatten(df: Generator) -> dict[str, datasets.Dataset]:
     return {
         "test": datasets.Dataset.from_list(
             list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
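flatten collapses the per-sequence-length sample generators into a single test split. A quick usage sketch, assuming this module imports as lm_eval.tasks.ruler.utils (the sample dicts are made up):

from lm_eval.tasks.ruler.utils import flatten

# Three batches of two samples each collapse into one six-row test split.
batches = ([{"input": f"sample {i}.{j}"} for j in range(2)] for i in range(3))
ds = flatten(batches)
assert ds["test"].num_rows == 6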
@@ -60,7 +69,7 @@ def flatten(df):
 # ruff: noqa
-niah_single_1 = lambda x: flatten(
+niah_single_1 = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="repeat"),
         max_seq_length=seq,
@@ -68,7 +77,7 @@ niah_single_1 = lambda x: flatten(
         type_haystack="repeat",
         type_needle_k="words",
         type_needle_v="numbers",
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )
@@ -86,7 +95,7 @@ niah_single_2 = lambda x: flatten(
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_single_3 = lambda x: flatten(
+niah_single_3 = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -94,12 +103,12 @@ niah_single_3 = lambda x: flatten(
         type_haystack="essay",
         type_needle_k="words",
         type_needle_v="uuids",
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_multikey_1 = lambda x: flatten(
+niah_multikey_1 = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -108,12 +117,12 @@ niah_multikey_1 = lambda x: flatten(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_k=4,
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_multikey_2 = lambda x: flatten(
+niah_multikey_2 = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="needle"),
         max_seq_length=seq,
@@ -121,12 +130,12 @@ niah_multikey_2 = lambda x: flatten(
         type_haystack="needle",
         type_needle_k="words",
         type_needle_v="numbers",
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_multikey_3 = lambda x: flatten(
+niah_multikey_3 = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="needle"),
         max_seq_length=seq,
@@ -134,12 +143,12 @@ niah_multikey_3 = lambda x: flatten(
         type_haystack="needle",
         type_needle_k="uuids",
         type_needle_v="uuids",
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_multivalue = lambda x: flatten(
+niah_multivalue = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -148,12 +157,12 @@ niah_multivalue = lambda x: flatten(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_v=4,
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )
 # noqa
-niah_multiquery = lambda x: flatten(
+niah_multiquery = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -162,7 +171,7 @@ niah_multiquery = lambda x: flatten(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_q=4,
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )
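Taken together, every niah_* entry point now takes the harness's keyword calling convention, so a task config can plug one directly into the new download_dataset hook. Hypothetical wiring, kept in comments because the YAML side is not part of this commit:

# In the task YAML (assumed keys; `!function` is the harness's loader
# for Python callables):
#
#   download_dataset: !function utils.niah_single_1
#   metadata:
#     tokenizer: gpt2
#
# ConfigurableTask.download() then effectively runs
#   self.dataset = niah_single_1(metadata=self.config.metadata)
# yielding the {"test": datasets.Dataset} mapping built by flatten().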
...
@@ -24,8 +24,7 @@ import numpy as np
 from tqdm import tqdm
 from transformers import AutoTokenizer
-from lm_eval.tasks.ruler.prepare import SEQ_LENGTHS
+from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
 
 TASKS = {
     "variable_tracking": {
...