Commit fabd0d90 authored by Baber's avatar Baber
Browse files

nit

parent 40f9dc14
...@@ -9,14 +9,7 @@ if TYPE_CHECKING: ...@@ -9,14 +9,7 @@ if TYPE_CHECKING:
import transformers import transformers
DEFAULT_SEQ_LENGTHS = ( DEFAULT_SEQ_LENGTHS = (4096,)
# 131072,
# 65536,
# 32768,
# 16384,
# 8192,
4096,
)
@cache @cache
......
import itertools import itertools
import logging
from typing import Generator from typing import Generator
import datasets import datasets
...@@ -8,6 +9,7 @@ from lm_eval.tasks.ruler.prepare_niah import generate_samples, get_haystack ...@@ -8,6 +9,7 @@ from lm_eval.tasks.ruler.prepare_niah import generate_samples, get_haystack
TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?""" TEMPLATE = """Some special magic {type_needle_v} are hidden within the following text. Make sure to memorize it. I will quiz you about the {type_needle_v} afterwards.\n{context}\nWhat are all the special magic {type_needle_v} for {query} mentioned in the provided text?"""
eval_logger = logging.getLogger(__name__)
def download_dataset(df: Generator) -> dict[str, datasets.Dataset]: def download_dataset(df: Generator) -> dict[str, datasets.Dataset]:
...@@ -28,6 +30,7 @@ def niah_single_1(**kwargs): ...@@ -28,6 +30,7 @@ def niah_single_1(**kwargs):
type_haystack="repeat", type_haystack="repeat",
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs), TOKENIZER=get_tokenizer(**kwargs),
) )
for seq in seq_lengths for seq in seq_lengths
...@@ -44,6 +47,7 @@ def niah_single_2(**kwargs): ...@@ -44,6 +47,7 @@ def niah_single_2(**kwargs):
type_haystack="essay", type_haystack="essay",
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs), TOKENIZER=get_tokenizer(**kwargs),
) )
for seq in seq_lengths for seq in seq_lengths
...@@ -60,6 +64,7 @@ def niah_single_3(**kwargs): ...@@ -60,6 +64,7 @@ def niah_single_3(**kwargs):
type_haystack="essay", type_haystack="essay",
type_needle_k="words", type_needle_k="words",
type_needle_v="uuids", type_needle_v="uuids",
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs), TOKENIZER=get_tokenizer(**kwargs),
) )
for seq in seq_lengths for seq in seq_lengths
...@@ -77,6 +82,7 @@ def niah_multikey_1(**kwargs): ...@@ -77,6 +82,7 @@ def niah_multikey_1(**kwargs):
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
num_needle_k=4, num_needle_k=4,
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs), TOKENIZER=get_tokenizer(**kwargs),
) )
for seq in seq_lengths for seq in seq_lengths
...@@ -93,6 +99,7 @@ def niah_multikey_2(**kwargs): ...@@ -93,6 +99,7 @@ def niah_multikey_2(**kwargs):
type_haystack="needle", type_haystack="needle",
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs), TOKENIZER=get_tokenizer(**kwargs),
) )
for seq in seq_lengths for seq in seq_lengths
...@@ -109,6 +116,7 @@ def niah_multikey_3(**kwargs): ...@@ -109,6 +116,7 @@ def niah_multikey_3(**kwargs):
type_haystack="needle", type_haystack="needle",
type_needle_k="uuids", type_needle_k="uuids",
type_needle_v="uuids", type_needle_v="uuids",
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs), TOKENIZER=get_tokenizer(**kwargs),
) )
for seq in seq_lengths for seq in seq_lengths
...@@ -126,6 +134,7 @@ def niah_multivalue(**kwargs): ...@@ -126,6 +134,7 @@ def niah_multivalue(**kwargs):
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
num_needle_v=4, num_needle_v=4,
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs), TOKENIZER=get_tokenizer(**kwargs),
) )
for seq in seq_lengths for seq in seq_lengths
...@@ -143,6 +152,7 @@ def niah_multiquery(**kwargs): ...@@ -143,6 +152,7 @@ def niah_multiquery(**kwargs):
type_needle_k="words", type_needle_k="words",
type_needle_v="numbers", type_needle_v="numbers",
num_needle_q=4, num_needle_q=4,
num_samples=500,
TOKENIZER=get_tokenizer(**kwargs), TOKENIZER=get_tokenizer(**kwargs),
) )
for seq in seq_lengths for seq in seq_lengths
......
...@@ -112,8 +112,8 @@ def simple_parse_args_string(args_string: Optional[str]) -> dict: ...@@ -112,8 +112,8 @@ def simple_parse_args_string(args_string: Optional[str]) -> dict:
return args_dict return args_dict
def parse_keyed_list_string(s: str) -> dict[str, list]: def parse_keyed_list_string(s: str) -> dict[str, tuple]:
"""Parse a string of key-value pairs into a dictionary where all values are lists.""" """Parse a string of key-value pairs into a dictionary where all values are tuples."""
if s is None: if s is None:
return {} return {}
result = {} result = {}
...@@ -126,7 +126,7 @@ def parse_keyed_list_string(s: str) -> dict[str, list]: ...@@ -126,7 +126,7 @@ def parse_keyed_list_string(s: str) -> dict[str, list]:
if "=" in part: if "=" in part:
# Save previous key's values # Save previous key's values
if current_key is not None: if current_key is not None:
result[current_key] = values result[current_key] = tuple(values)
# Start new key-value pair # Start new key-value pair
current_key, value = part.split("=") current_key, value = part.split("=")
...@@ -136,7 +136,7 @@ def parse_keyed_list_string(s: str) -> dict[str, list]: ...@@ -136,7 +136,7 @@ def parse_keyed_list_string(s: str) -> dict[str, list]:
# Add the last key-value pair # Add the last key-value pair
if current_key is not None: if current_key is not None:
result[current_key] = values result[current_key] = tuple(values)
return result return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment