Commit fabd0d90 authored by Baber
Browse files

nit

parent 40f9dc14
......@@ -9,14 +9,7 @@ if TYPE_CHECKING:
import transformers
DEFAULT_SEQ_LENGTHS = (
# 131072,
# 65536,
# 32768,
# 16384,
# 8192,
4096,
)
DEFAULT_SEQ_LENGTHS = (4096,)
@cache
......
import itertools
import logging
from typing import Generator
import datasets
......@@ -8,6 +9,7 @@ from lm_eval.tasks.ruler.prepare_niah import generate_samples, get_haystack
TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?"""
eval_logger = logging.getLogger(__name__)
def download_dataset(df: Generator) -> dict[str, datasets.Dataset]:
......@@ -28,6 +30,7 @@ def niah_single_1(**kwargs):
type_haystack="repeat",
type_needle_k="words",
type_needle_v="numbers",
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs),
)
for seq in seq_lengths
......@@ -44,6 +47,7 @@ def niah_single_2(**kwargs):
type_haystack="essay",
type_needle_k="words",
type_needle_v="numbers",
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs),
)
for seq in seq_lengths
......@@ -60,6 +64,7 @@ def niah_single_3(**kwargs):
type_haystack="essay",
type_needle_k="words",
type_needle_v="uuids",
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs),
)
for seq in seq_lengths
......@@ -77,6 +82,7 @@ def niah_multikey_1(**kwargs):
type_needle_k="words",
type_needle_v="numbers",
num_needle_k=4,
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs),
)
for seq in seq_lengths
......@@ -93,6 +99,7 @@ def niah_multikey_2(**kwargs):
type_haystack="needle",
type_needle_k="words",
type_needle_v="numbers",
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs),
)
for seq in seq_lengths
......@@ -109,6 +116,7 @@ def niah_multikey_3(**kwargs):
type_haystack="needle",
type_needle_k="uuids",
type_needle_v="uuids",
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs),
)
for seq in seq_lengths
......@@ -126,6 +134,7 @@ def niah_multivalue(**kwargs):
type_needle_k="words",
type_needle_v="numbers",
num_needle_v=4,
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs),
)
for seq in seq_lengths
......@@ -143,6 +152,7 @@ def niah_multiquery(**kwargs):
type_needle_k="words",
type_needle_v="numbers",
num_needle_q=4,
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs),
)
for seq in seq_lengths
......
......@@ -112,8 +112,8 @@ def simple_parse_args_string(args_string: Optional[str]) -> dict:
return args_dict
def parse_keyed_list_string(s: str) -> dict[str, list]:
"""Parse a string of key-value pairs into a dictionary where all values are lists."""
def parse_keyed_list_string(s: str) -> dict[str, tuple]:
"""Parse a string of key-value pairs into a dictionary where all values are tuples."""
if s is None:
return {}
result = {}
......@@ -126,7 +126,7 @@ def parse_keyed_list_string(s: str) -> dict[str, list]:
if "=" in part:
# Save previous key's values
if current_key is not None:
result[current_key] = values
result[current_key] = tuple(values)
# Start new key-value pair
current_key, value = part.split("=")
......@@ -136,7 +136,7 @@ def parse_keyed_list_string(s: str) -> dict[str, list]:
# Add the last key-value pair
if current_key is not None:
result[current_key] = values
result[current_key] = tuple(values)
return result
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment