"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "253e52b1c2e1a9169b64a6d4f01063db2f8163fc"
Commit 2bc9aa6d authored by Baber

cleanup; pass metadata

parent b9614a3e
@@ -10,7 +10,12 @@ from lm_eval import evaluator, utils
 from lm_eval.evaluator import request_caching_arg_to_dict
 from lm_eval.loggers import EvaluationTracker, WandbLogger
 from lm_eval.tasks import TaskManager
-from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string
+from lm_eval.utils import (
+    handle_non_serializable,
+    make_table,
+    parse_keyed_list_string,
+    simple_parse_args_string,
+)


 def _int_or_none_list_arg_type(
@@ -266,7 +271,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--metadata",
         type=str,
         default=None,
-        help="Comma separated string argument metadata to pass to task configs, for example max_context_len=4096,8192 etc.",
+        help="Comma-separated key=value metadata to pass to task configs, for example max_context_len=4096,8192. Parsed as a dictionary with all values as lists.",
     )
     return parser
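Worth spelling out what the new flag produces: every parsed value becomes a list, even a singleton. A minimal sketch of the round trip, using the `parse_keyed_list_string` helper this commit adds to `lm_eval.utils` (the `None` guard is ours; as diffed, `cli_evaluate` passes `args.metadata` through unconditionally, and the helper calls `.split(",")` on its argument):

```python
from lm_eval.utils import parse_keyed_list_string

raw = "max_context_len=4096,8192,tokenizer=gpt2"  # hypothetical CLI value
meta = parse_keyed_list_string(raw) if raw is not None else {}
# Every value is a list, even singletons:
# {'max_context_len': [4096, 8192], 'tokenizer': ['gpt2']}
print(meta)
```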
@@ -416,7 +421,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         torch_random_seed=args.seed[2],
         fewshot_random_seed=args.seed[3],
         confirm_run_unsafe_code=args.confirm_run_unsafe_code,
-        metadata=args.metadata,
+        metadata=parse_keyed_list_string(args.metadata),
         **request_caching_args,
     )
...
@@ -821,15 +821,7 @@ class ConfigurableTask(Task):
             )
             self._higher_is_better[metric_name] = is_higher_better(metric_name)

-        if self.config.download_dataset is None:
-            self.download(self.config.dataset_kwargs)
-        else:
-            self.dataset = self.config.download_dataset(
-                metadata=self.config.metadata,
-                **self.config.dataset_kwargs
-                if self.config.dataset_kwargs is not None
-                else {},
-            )
+        self.download(self.config.dataset_kwargs)
         self._training_docs = None
         self._fewshot_docs = None
@@ -938,11 +930,19 @@
     def download(
         self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
     ) -> None:
-        self.dataset = datasets.load_dataset(
-            path=self.DATASET_PATH,
-            name=self.DATASET_NAME,
-            **dataset_kwargs if dataset_kwargs is not None else {},
-        )
+        if isinstance(self.config.download_dataset, Callable):
+            self.dataset = self.config.download_dataset(
+                **self.config.metadata,
+                **self.config.dataset_kwargs
+                if self.config.dataset_kwargs is not None
+                else {},
+            )
+        else:
+            self.dataset = datasets.load_dataset(
+                path=self.DATASET_PATH,
+                name=self.DATASET_NAME,
+                **dataset_kwargs if dataset_kwargs is not None else {},
+            )

     def has_training_docs(self) -> bool:
         if self.config.training_split is not None:
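The point of the new dispatch above: a task config may supply its own `download_dataset` callable, which now receives the flattened `metadata` entries and any `dataset_kwargs` as plain keyword arguments. A hypothetical loader written against that convention (the function name and parameters are illustrative, not part of this commit):

```python
import datasets

# Hypothetical loader: ConfigurableTask.download() invokes it as
# download_dataset(**config.metadata, **config.dataset_kwargs), so metadata
# keys such as max_context_len arrive as ordinary keyword arguments.
def download_ruler_like(max_context_len=None, tokenizer=None, **dataset_kwargs):
    # dataset_kwargs (e.g. data_files) pass straight through to HF datasets
    ds = datasets.load_dataset("json", **dataset_kwargs)
    # metadata values are lists (see parse_keyed_list_string), e.g. [4096, 8192]
    return {str(n): ds for n in (max_context_len or [])}
```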
...
@@ -234,12 +234,15 @@ def simple_evaluate(
     if task_manager is None:
         task_manager = TaskManager(verbosity)

     # TODO fix this. hack to get around the fact that we can't pass model to task config
     task_dict = get_task_dict(
         tasks,
         task_manager,
-        metadata=simple_parse_args_string(model_args)
-        | simple_parse_args_string(metadata),
+        metadata=(
+            simple_parse_args_string(model_args)
+            if isinstance(model_args, str)
+            else model_args
+        )
+        | (metadata or {}),
     )

     # helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
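Two details in the replacement: `model_args` may already be a parsed dict rather than a string (hence the `isinstance` branch), and `metadata or {}` keeps the union defined when `metadata` is `None`. With `|`, the right operand wins on key collisions, so explicit `metadata` overrides anything inferred from `model_args`:

```python
# Dict-union semantics used above (values are hypothetical):
model_meta = {"pretrained": "gpt2", "max_context_len": [2048]}
metadata = {"max_context_len": [4096, 8192]}

merged = model_meta | (metadata or {})
# {'pretrained': 'gpt2', 'max_context_len': [4096, 8192]}
# the right-hand operand wins on key collisions
```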
...
@@ -170,7 +170,6 @@ def get_dataset(pretrained, seq=None, **kwargs):
 def get_cw_dataset(**kwargs):
-    kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     df = (
         get_dataset(pretrained, seq=seq)
...
@@ -157,7 +157,6 @@ def get_dataset(pretrained, max_seq_length=None, **kwargs):
 def fwe_download(**kwargs):
-    kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     df = (
         get_dataset(pretrained, max_seq_length=seq)
...
@@ -28,7 +28,7 @@ def niah_single_1(**kwargs):
             type_haystack="repeat",
             type_needle_k="words",
             type_needle_v="numbers",
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -44,7 +44,7 @@ def niah_single_2(**kwargs):
             type_haystack="essay",
             type_needle_k="words",
             type_needle_v="numbers",
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -60,7 +60,7 @@ def niah_single_3(**kwargs):
             type_haystack="essay",
             type_needle_k="words",
             type_needle_v="uuids",
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -77,7 +77,7 @@ def niah_multikey_1(**kwargs):
             type_needle_k="words",
             type_needle_v="numbers",
             num_needle_k=4,
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -93,7 +93,7 @@ def niah_multikey_2(**kwargs):
             type_haystack="needle",
             type_needle_k="words",
             type_needle_v="numbers",
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -109,7 +109,7 @@ def niah_multikey_3(**kwargs):
             type_haystack="needle",
             type_needle_k="uuids",
             type_needle_v="uuids",
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -126,7 +126,7 @@ def niah_multivalue(**kwargs):
             type_needle_k="words",
             type_needle_v="numbers",
             num_needle_v=4,
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
@@ -143,7 +143,7 @@ def niah_multiquery(**kwargs):
             type_needle_k="words",
             type_needle_v="numbers",
             num_needle_q=4,
-            TOKENIZER=get_tokenizer(**kwargs.get("metadata")),
+            TOKENIZER=get_tokenizer(**kwargs),
         )
         for seq in seq_lengths
     )
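The same one-line change repeats across every `niah_*` builder: the nested `kwargs.get("metadata")` unwrap disappears because `ConfigurableTask.download` now spreads `config.metadata` straight into these callables. A hedged sketch of the calling convention this implies; `get_tokenizer`'s real signature lives in the ruler task utilities and may differ:

```python
# Illustrative only: shows the flattened-kwargs convention the niah_*
# builders now rely on, not the actual ruler implementation.
from transformers import AutoTokenizer

def get_tokenizer(tokenizer=None, pretrained=None, **_ignored):
    # Metadata keys now arrive as top-level kwargs instead of nested
    # under kwargs["metadata"].
    return AutoTokenizer.from_pretrained(tokenizer or pretrained)
```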
@@ -215,7 +215,6 @@ def get_dataset(pretrained, docs, qas, max_seq_length=None, **kwargs) -> list[dict]:
 def get_qa_dataset(ds, **kwargs) -> dict[str, datasets.Dataset]:
-    kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     if ds == "squad":
         qas, docs = read_squad()
...
@@ -237,7 +237,6 @@ def get_dataset(pretrained, seq=None, **kwargs) -> list[dict]:
 def get_vt_dataset(**kwargs) -> dict[str, datasets.Dataset]:
-    kwargs = kwargs.get("metadata", {})
     pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     df = (
         get_dataset(pretrained, seq=seq)
...
@@ -112,6 +112,33 @@ def simple_parse_args_string(args_string: Optional[str]) -> dict:
     return args_dict


+def parse_keyed_list_string(s: str) -> dict[str, list]:
+    """Parse a string of key-value pairs into a dictionary where all values are lists."""
+    result = {}
+    current_key = None
+    values = []
+    parts = s.split(",")
+    for part in parts:
+        if "=" in part:
+            # Save the previous key's values
+            if current_key is not None:
+                result[current_key] = values
+            # Start a new key-value pair
+            current_key, value = part.split("=")
+            values = [handle_arg_string(value)]
+        else:
+            values.append(handle_arg_string(part))
+    # Add the last key-value pair
+    if current_key is not None:
+        result[current_key] = values
+    return result


 def join_iters(iters):
     for iter in iters:
         yield from iter
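Two edge cases in the parser as written deserve a note: tokens before the first `=` are silently dropped (there is no `current_key` to attach them to), and `part.split("=")` assumes values never contain an `=` of their own. On well-formed input it behaves like this (`handle_arg_string` coerces numerals to `int`):

```python
parse_keyed_list_string("max_context_len=4096,8192")
# {'max_context_len': [4096, 8192]}

parse_keyed_list_string("a=1,2,b=x")
# {'a': [1, 2], 'b': ['x']}
```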