Commit 2bc9aa6d authored by Baber

cleanup; pass metadata

parent b9614a3e
@@ -10,7 +10,12 @@ from lm_eval import evaluator, utils
 from lm_eval.evaluator import request_caching_arg_to_dict
 from lm_eval.loggers import EvaluationTracker, WandbLogger
 from lm_eval.tasks import TaskManager
-from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string
+from lm_eval.utils import (
+    handle_non_serializable,
+    make_table,
+    parse_keyed_list_string,
+    simple_parse_args_string,
+)


 def _int_or_none_list_arg_type(
@@ -266,7 +271,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--metadata",
         type=str,
         default=None,
-        help="Comma separated string argument metadata to pass to task configs, for example max_context_len=4096,8192 etc.",
+        help="Comma separated string argument metadata to pass to task configs, for example max_context_len=4096,8192. Will be parsed as a dictionary with all values as lists.",
     )
     return parser
@@ -416,7 +421,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         torch_random_seed=args.seed[2],
         fewshot_random_seed=args.seed[3],
         confirm_run_unsafe_code=args.confirm_run_unsafe_code,
-        metadata=args.metadata,
+        metadata=parse_keyed_list_string(args.metadata),
         **request_caching_args,
     )
...
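Taken together, the two hunks above mean the raw --metadata string is parsed into a dict before it reaches simple_evaluate. A minimal sketch of the round trip, reusing the example value from the new help text (parse_keyed_list_string is the helper added to lm_eval/utils.py at the end of this commit):

    from lm_eval.utils import parse_keyed_list_string

    # The example from the --metadata help text: one key, two comma-separated values.
    metadata = parse_keyed_list_string("max_context_len=4096,8192")
    print(metadata)  # {'max_context_len': [4096, 8192]}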
@@ -821,15 +821,7 @@ class ConfigurableTask(Task):
             )
             self._higher_is_better[metric_name] = is_higher_better(metric_name)

-        if self.config.download_dataset is None:
-            self.download(self.config.dataset_kwargs)
-        else:
-            self.dataset = self.config.download_dataset(
-                metadata=self.config.metadata,
-                **self.config.dataset_kwargs
-                if self.config.dataset_kwargs is not None
-                else {},
-            )
+        self.download(self.config.dataset_kwargs)

         self._training_docs = None
         self._fewshot_docs = None
@@ -938,11 +930,19 @@ class ConfigurableTask(Task):
     def download(
         self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
     ) -> None:
-        self.dataset = datasets.load_dataset(
-            path=self.DATASET_PATH,
-            name=self.DATASET_NAME,
-            **dataset_kwargs if dataset_kwargs is not None else {},
-        )
+        if isinstance(self.config.download_dataset, Callable):
+            self.dataset = self.config.download_dataset(
+                **self.config.metadata,
+                **self.config.dataset_kwargs
+                if self.config.dataset_kwargs is not None
+                else {},
+            )
+        else:
+            self.dataset = datasets.load_dataset(
+                path=self.DATASET_PATH,
+                name=self.DATASET_NAME,
+                **dataset_kwargs if dataset_kwargs is not None else {},
+            )

     def has_training_docs(self) -> bool:
         if self.config.training_split is not None:
...
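This rewrite folds the old constructor-time branch into download() itself and changes the calling convention for custom downloaders: config.metadata is no longer passed as a single metadata= keyword but splatted alongside dataset_kwargs, so the callable receives one flat set of keyword arguments. A hypothetical downloader under the new convention (the key names tokenizer and max_seq_lengths are illustrative, mirroring the RULER helpers changed below):

    import datasets

    def my_download_dataset(**kwargs):
        # Metadata keys now arrive at the top level, not nested under "metadata".
        pretrained = kwargs.get("tokenizer", kwargs.get("pretrained"))
        seq_lengths = kwargs.get("max_seq_lengths", [4096])
        rows = [{"pretrained": pretrained, "seq": s} for s in seq_lengths]
        return {"test": datasets.Dataset.from_list(rows)}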
@@ -234,12 +234,15 @@ def simple_evaluate(
     if task_manager is None:
         task_manager = TaskManager(verbosity)

+    # TODO fix this. hack to get around the fact that we can't pass model to task config
     task_dict = get_task_dict(
         tasks,
         task_manager,
-        metadata=simple_parse_args_string(model_args)
-        | simple_parse_args_string(metadata),
+        metadata=(
+            simple_parse_args_string(model_args)
+            if isinstance(model_args, str)
+            else model_args
+        )
+        | (metadata or {}),
     )

     # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
...
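This merge decides what get_task_dict receives as metadata: model_args may now be either a raw string or an already-parsed dict, and the CLI metadata dict (or {} when the flag is unset) is unioned in on the right-hand side of |, so its keys win on collision. A small sketch of the semantics:

    from lm_eval.utils import simple_parse_args_string

    model_args = "pretrained=gpt2,max_length=2048"
    metadata = {"max_length": [4096]}  # e.g. output of parse_keyed_list_string
    merged = simple_parse_args_string(model_args) | (metadata or {})
    print(merged)  # {'pretrained': 'gpt2', 'max_length': [4096]}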
@@ -170,7 +170,6 @@ def get_dataset(pretrained, seq=None, **kwargs):

 def get_cw_dataset(**kwargs):
-    kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     df = (
         get_dataset(pretrained, seq=seq)
...
@@ -157,7 +157,6 @@ def get_dataset(pretrained, max_seq_length=None, **kwargs):

 def fwe_download(**kwargs):
-    kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     df = (
         get_dataset(pretrained, max_seq_length=seq)
...
@@ -28,7 +28,7 @@ def niah_single_1(**kwargs):
             type_haystack="repeat",
             type_needle_k="words",
             type_needle_v="numbers",
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -44,7 +44,7 @@ def niah_single_2(**kwargs):
             type_haystack="essay",
             type_needle_k="words",
             type_needle_v="numbers",
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -60,7 +60,7 @@ def niah_single_3(**kwargs):
             type_haystack="essay",
             type_needle_k="words",
             type_needle_v="uuids",
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -77,7 +77,7 @@ def niah_multikey_1(**kwargs):
             type_needle_k="words",
             type_needle_v="numbers",
             num_needle_k=4,
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -93,7 +93,7 @@ def niah_multikey_2(**kwargs):
             type_haystack="needle",
             type_needle_k="words",
             type_needle_v="numbers",
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -109,7 +109,7 @@ def niah_multikey_3(**kwargs):
             type_haystack="needle",
             type_needle_k="uuids",
             type_needle_v="uuids",
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -126,7 +126,7 @@ def niah_multivalue(**kwargs):
             type_needle_k="words",
             type_needle_v="numbers",
             num_needle_v=4,
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -143,7 +143,7 @@ def niah_multiquery(**kwargs):
             type_needle_k="words",
             type_needle_v="numbers",
             num_needle_q=4,
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
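All eight NIAH call sites change the same way: get_tokenizer was previously handed the nested kwargs.get("metadata") dict, but metadata is now merged into the task kwargs at the top level, so the full **kwargs can be forwarded directly. get_tokenizer itself is not part of this diff; a hedged sketch of the lookup it presumably performs, mirroring the pretrained lookup in the dataset helpers:

    from transformers import AutoTokenizer

    def get_tokenizer_sketch(**kwargs):
        # Hypothetical: resolve the tokenizer name from the flat kwargs.
        name = kwargs.get("tokenizer", kwargs.get("pretrained"))
        return AutoTokenizer.from_pretrained(name)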
@@ -215,7 +215,6 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs) -> list[dict]:

 def get_qa_dataset(ds, **kwargs) -> dict[str, datasets.Dataset]:
-    kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     if ds == "squad":
         qas, docs = read_squad()
...
@@ -237,7 +237,6 @@ def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]:

 def get_vt_dataset(**kwargs) -> dict[str, datasets.Dataset]:
-    kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     df = (
         get_dataset(pretrained, seq=seq)
...
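The four dataset helpers (get_cw_dataset, fwe_download, get_qa_dataset, get_vt_dataset) all drop the same unwrapping line, since the nested dict they used to unwrap no longer exists. Roughly, assuming a tokenizer key in the metadata:

    # Old call shape (helper unwrapped the nested dict itself):
    #   get_vt_dataset(metadata={"tokenizer": "gpt2", ...})
    # New call shape (download() splats config.metadata up front):
    #   get_vt_dataset(tokenizer="gpt2", ...)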
@@ -112,6 +112,33 @@ def simple_parse_args_string(args_string: Optional[str]) -> dict:
     return args_dict


+def parse_keyed_list_string(s: str) -> dict[str, list]:
+    """Parse a string of key-value pairs into a dictionary where all values are lists."""
+    result = {}
+    current_key = None
+    values = []
+    parts = s.split(",")
+    for part in parts:
+        if "=" in part:
+            # Save previous key's values
+            if current_key is not None:
+                result[current_key] = values
+            # Start new key-value pair
+            current_key, value = part.split("=")
+            values = [handle_arg_string(value)]
+        else:
+            values.append(handle_arg_string(part))
+    # Add the last key-value pair
+    if current_key is not None:
+        result[current_key] = values
+    return result
+
+
 def join_iters(iters):
     for iter in iters:
         yield from iter
...
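One non-obvious detail of the parser above: a comma-separated chunk without an "=" attaches as another value to the most recent key, and handle_arg_string coerces numeric and boolean strings, so every key maps to a list of typed values. For example (the key names here are illustrative):

    from lm_eval.utils import parse_keyed_list_string

    # Bare values join the preceding key's list; "4096"/"8192" become ints.
    assert parse_keyed_list_string("seq=4096,8192,tokenizer=gpt2") == {
        "seq": [4096, 8192],
        "tokenizer": ["gpt2"],
    }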