Commit c1e43393 authored by artemorloff

add separate eval_config class

parent b2e1bfc6
@@ -8,17 +8,21 @@ from pathlib import Path
from typing import Union
from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
# from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.loggers import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.utils import (
TrackExplicitAction,
TrackExplicitStoreTrue,
handle_non_serializable,
load_yaml_config,
make_table,
non_default_update,
parse_namespace,
request_caching_arg_to_dict,
# non_default_update,
# parse_namespace,
)
from lm_eval.api.eval_config import (
TrackExplicitAction,
TrackExplicitStoreTrue,
EvaluationConfig,
)
@@ -353,96 +357,93 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
parser = setup_parser()
args = parse_eval_args(parser)
# get namespace from console, including config arg
config, non_default_args = parse_namespace(args)
# if config is passed, load it
if config.get("config", False):
local_config = load_yaml_config(yaml_path=config["config"])
config = non_default_update(config, local_config, non_default_args)
config = EvaluationConfig.from_cli(args)
if args.wandb_args:
wandb_logger = WandbLogger(config["wandb_args"], config["wandb_config_args"])
wandb_logger = WandbLogger(config.wandb_args, config.wandb_config_args)
utils.setup_logging(config["verbosity"])
# utils.setup_logging(args.verbosity)
utils.setup_logging(config.verbosity)
eval_logger = logging.getLogger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# update the evaluation tracker args with the output path and the HF token
if config["output_path"]:
config.setdefault("hf_hub_log_args", {})["output_path"] = config["output_path"]
if config.output_path:
config.hf_hub_log_args["output_path"] = config.output_path
if os.environ.get("HF_TOKEN", None):
config.setdefault("hf_hub_log_args", {})["token"] = os.environ.get("HF_TOKEN")
evaluation_tracker_args = config["hf_hub_log_args"]
config.hf_hub_log_args["token"] = os.environ.get("HF_TOKEN")
evaluation_tracker_args = config.hf_hub_log_args
evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
if config["predict_only"]:
config["log_samples"] = True
if config.predict_only:
config.log_samples = True
if (config["log_samples"] or config["predict_only"]) and not config["output_path"]:
if (config.log_samples or config.predict_only) and not config.output_path:
raise ValueError(
"Specify --output_path if providing --log_samples or --predict_only"
)
if config["fewshot_as_multiturn"] and config["apply_chat_template"] is False:
if config.fewshot_as_multiturn and config.apply_chat_template is False:
raise ValueError(
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
)
if config["include_path"] is not None:
eval_logger.info(f"Including path: {config['include_path']}")
if config.include_path is not None:
eval_logger.info(f"Including path: {config.include_path}")
metadata = (config["model_args"]) | (config["metadata"])
metadata = (config.model_args) | (config.metadata)
config.metadata = metadata
task_manager = TaskManager(include_path=config["include_path"], metadata=metadata)
# task_manager = TaskManager(include_path=config["include_path"], metadata=metadata)
task_manager = TaskManager(include_path=config.include_path, metadata=metadata)
if "push_samples_to_hub" in evaluation_tracker_args and not config["log_samples"]:
if "push_samples_to_hub" in evaluation_tracker_args and not config.log_samples:
eval_logger.warning(
"Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
)
if config["limit"]:
if config.limit:
eval_logger.warning(
" --limit SHOULD ONLY BE USED FOR TESTING."
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if config["samples"]:
assert config["limit"] is None, (
if config.samples:
assert config.limit is None, (
"If --samples is not None, then --limit must be None."
)
if (samples := Path(config["samples"])).is_file():
config["samples"] = json.loads(samples.read_text())
if (samples := Path(config.samples)).is_file():
config.samples = json.loads(samples.read_text())
else:
config["samples"] = json.loads(config["samples"])
config.samples = json.loads(config.samples)
if config["tasks"] is None:
if config.tasks is None:
eval_logger.error("Need to specify task to evaluate.")
sys.exit()
elif config["tasks"] == "list":
elif config.tasks == "list":
print(task_manager.list_all_tasks())
sys.exit()
elif config["tasks"] == "list_groups":
elif config.tasks == "list_groups":
print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
sys.exit()
elif config["tasks"] == "list_tags":
elif config.tasks == "list_tags":
print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
sys.exit()
elif config["tasks"] == "list_subtasks":
elif config.tasks == "list_subtasks":
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
sys.exit()
else:
if os.path.isdir(config["tasks"]):
if os.path.isdir(config.tasks):
import glob
task_names = []
yaml_path = os.path.join(config["tasks"], "*.yaml")
yaml_path = os.path.join(config.tasks, "*.yaml")
for yaml_file in glob.glob(yaml_path):
config = utils.load_yaml_config(yaml_file)
task_names.append(config)
else:
task_list = config["tasks"].split(",")
task_list = config.tasks.split(",")
task_names = task_manager.match_tasks(task_list)
for task in [task for task in task_list if task not in task_names]:
if os.path.isfile(task):
@@ -461,9 +462,10 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
raise ValueError(
f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
)
config.tasks = task_names
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
if config["trust_remote_code"]:
if config.trust_remote_code:
eval_logger.info(
"Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
)
@@ -474,7 +476,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
config.setdefault("model_args", {})["trust_remote_code"] = True
config.model_args["trust_remote_code"] = True
(
eval_logger.info(f"Selected Tasks: {task_names}")
if eval_logger.getEffectiveLevel() >= logging.INFO
@@ -482,56 +484,35 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
)
request_caching_args = request_caching_arg_to_dict(
cache_requests=config["cache_requests"]
cache_requests=config.cache_requests
)
config.request_caching_args = request_caching_args
print(f"CONFIG_AFTER: {config}")
results = evaluator.simple_evaluate(
model=config["model"],
model_args=config["model_args"],
tasks=task_names,
num_fewshot=config["num_fewshot"],
batch_size=config["batch_size"],
max_batch_size=config["max_batch_size"],
device=config["device"],
use_cache=config["use_cache"],
limit=config["limit"],
samples=config["samples"],
check_integrity=config["check_integrity"],
write_out=config["write_out"],
log_samples=config["log_samples"],
config=config,
evaluation_tracker=evaluation_tracker,
system_instruction=config["system_instruction"],
apply_chat_template=config["apply_chat_template"],
fewshot_as_multiturn=config["fewshot_as_multiturn"],
gen_kwargs=config["gen_kwargs"],
task_manager=task_manager,
predict_only=config["predict_only"],
random_seed=config["seed"][0],
numpy_random_seed=config["seed"][1],
torch_random_seed=config["seed"][2],
fewshot_random_seed=config["seed"][3],
confirm_run_unsafe_code=config["confirm_run_unsafe_code"],
metadata=metadata,
**request_caching_args,
)
if results is not None:
if config["log_samples"]:
if config.log_samples:
samples = results.pop("samples")
dumped = json.dumps(
results, indent=2, default=handle_non_serializable, ensure_ascii=False
)
if config["show_config"]:
if config.show_config:
print(dumped)
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
# Add W&B logging
if config["wandb_args"]:
if config.wandb_args:
try:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
if config["log_samples"]:
if config.log_samples:
wandb_logger.log_eval_samples(samples)
except Exception as e:
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
@@ -540,7 +521,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
results=results, samples=samples if args.log_samples else None
)
if config["log_samples"]:
if config.log_samples:
for task_name, _ in results["configs"].items():
evaluation_tracker.save_results_samples(
task_name=task_name, samples=samples[task_name]
@@ -553,14 +534,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
evaluation_tracker.recreate_metadata_card()
print(
f"{config['model']} ({config['model_args']}), gen_kwargs: ({config['gen_kwargs']}), limit: {config['limit']}, num_fewshot: {config['num_fewshot']}, "
f"batch_size: {config['batch_size']}{f' ({batch_sizes})' if batch_sizes else ''}"
f"{config.model} ({config.model_args}), gen_kwargs: ({config.gen_kwargs}), limit: {config.limit}, num_fewshot: {config.num_fewshot}, "
f"batch_size: {config.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
)
print(make_table(results))
if "groups" in results:
print(make_table(results, "groups"))
if config["wandb_args"]:
if config.wandb_args:
# Tear down wandb run once all the logging is done.
wandb_logger.run.finish()
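Taken together, the entry point now routes everything through a single config object. Below is a condensed sketch of the resulting flow; it is an illustration only (not part of this commit), it mirrors the names used in the diff above, and it omits task listing, sample logging, wandb teardown, and the remaining keyword plumbing of simple_evaluate.

# Condensed sketch of the new cli_evaluate flow; illustration only.
from lm_eval import evaluator
from lm_eval.api.eval_config import EvaluationConfig
from lm_eval.loggers import EvaluationTracker
from lm_eval.tasks import TaskManager
from lm_eval.utils import request_caching_arg_to_dict


def run_from_namespace(args):
    # Merge YAML (if --config was given) with explicitly passed CLI flags.
    config = EvaluationConfig.from_cli(args)
    config.metadata = config.model_args | config.metadata
    task_manager = TaskManager(include_path=config.include_path, metadata=config.metadata)
    config.tasks = task_manager.match_tasks(config.tasks.split(","))
    config.request_caching_args = request_caching_arg_to_dict(
        cache_requests=config.cache_requests
    )
    return evaluator.simple_evaluate(
        config=config,
        evaluation_tracker=EvaluationTracker(**config.hf_hub_log_args),
        task_manager=task_manager,
    )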
import os
import yaml
from argparse import Namespace
from typing import Any, Dict, Union, Optional
import argparse
from pydantic import BaseModel
from lm_eval.utils import simple_parse_args_string
DICT_KEYS = [
"wandb_args",
"wandb_config_args",
"hf_hub_log_args",
"metadata",
"model_args",
]
class EvaluationConfig(BaseModel):
"""
Simple config container for language-model evaluation.
No content validation here—just holds whatever comes from YAML or CLI.
"""
config: Optional[str]
model: Optional[str]
model_args: Optional[dict]
tasks: Optional[str]
num_fewshot: Optional[int]
batch_size: Optional[int]
max_batch_size: Optional[int]
device: Optional[str]
output_path: Optional[str]
limit: Optional[float]
samples: Optional[str]
use_cache: Optional[str]
cache_requests: Optional[str]
check_integrity: Optional[bool]
write_out: Optional[bool]
log_samples: Optional[bool]
predict_only: Optional[bool]
system_instruction: Optional[str]
apply_chat_template: Optional[Union[bool, str]]
fewshot_as_multiturn: Optional[bool]
show_config: Optional[bool]
include_path: Optional[str]
gen_kwargs: Optional[dict]
verbosity: Optional[str]
wandb_args: Optional[dict]
wandb_config_args: Optional[dict]
hf_hub_log_args: Optional[dict]
seed: Optional[list]
trust_remote_code: Optional[bool]
confirm_run_unsafe_code: Optional[bool]
metadata: Optional[dict]
request_caching_args: Optional[dict] = None
@staticmethod
def parse_namespace(namespace: argparse.Namespace) -> Dict[str, Any]:
"""
Convert an argparse Namespace object to a dictionary.
Handles all argument types including boolean flags, lists, and None values.
Only includes arguments that were explicitly set (ignores defaults unless used).
Args:
namespace: The argparse.Namespace object to convert
Returns:
Dictionary containing all the namespace arguments and their values
"""
config = {key: value for key, value in vars(namespace).items()}
for key in config:
if key == "_explicit_args":
continue
if key in DICT_KEYS:
if not isinstance(config[key], dict):
config[key] = simple_parse_args_string(config[key])
# if key == "cache_requests":
# config[key] = request_caching_arg_to_dict(config[key])
if "model_args" not in config:
config["model_args"] = {}
if "metadata" not in config:
config["metadata"] = {}
non_default_args = []
if hasattr(namespace, "_explicit_args"):
non_default_args = namespace._explicit_args
return config, non_default_args
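For illustration (made-up values, not taken from the diff): ordinary arguments pass through unchanged, while the DICT_KEYS entries, which arrive from argparse as comma-separated key=value strings, are run through simple_parse_args_string.

# Hand-built namespace for illustration only; a real one comes from the harness parser.
import argparse

from lm_eval.api.eval_config import EvaluationConfig

ns = argparse.Namespace(
    model="hf",
    model_args="pretrained=gpt2,dtype=float16",
    metadata={},          # already a dict, left untouched
    wandb_args="",
    wandb_config_args="",
    hf_hub_log_args="",
)
cfg, explicit = EvaluationConfig.parse_namespace(ns)
# cfg["model_args"] == {"pretrained": "gpt2", "dtype": "float16"}
# explicit == []   (no _explicit_args attribute on this hand-built namespace)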
@staticmethod
def non_default_update(console_dict, local_dict, non_default_args):
"""
Update local_dict by taking non-default values from console_dict.
Args:
console_dict (dict): The dictionary that potentially provides updates.
local_dict (dict): The dictionary to be updated.
non_default_args (list): The list of args passed by user in console.
Returns:
dict: The updated local_dict.
"""
result_config = {}
for key in set(console_dict.keys()).union(local_dict.keys()):
if key in non_default_args:
result_config[key] = console_dict[key]
else:
if key in local_dict:
result_config[key] = local_dict[key]
else:
result_config[key] = console_dict[key]
return result_config
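A small worked example of the precedence (made-up values): flags the user explicitly passed win, otherwise a YAML value is used, and otherwise the CLI default survives.

# Illustration only: --num_fewshot 5 was passed explicitly, the YAML sets
# batch_size, and device falls back to the CLI default.
console = {"num_fewshot": 5, "batch_size": 1, "device": None}
local = {"num_fewshot": 0, "batch_size": 8}
merged = EvaluationConfig.non_default_update(console, local, non_default_args=["num_fewshot"])
# merged == {"num_fewshot": 5, "batch_size": 8, "device": None}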
@classmethod
def from_cli(cls, namespace: Namespace) -> "EvaluationConfig":
"""
Build an EvaluationConfig by merging:
1) YAML config (if --config was passed), then
2) CLI args (only those the user actually provided)
CLI defaults that weren’t overridden explicitly will be
overwritten by YAML values if present.
"""
# 1. Turn Namespace into a dict + know which were explicitly passed
args_dict, explicit_args = EvaluationConfig.parse_namespace(namespace)
# 2. Start from YAML if requested
config_data: Dict[str, Any] = {}
if "config" in explicit_args and args_dict.get("config"):
cfg_path = args_dict["config"]
if not os.path.isfile(cfg_path):
raise FileNotFoundError(f"Config file not found: {cfg_path}")
try:
with open(cfg_path, "r") as f:
yaml_data = yaml.safe_load(f)
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {cfg_path}: {e}")
if not isinstance(yaml_data, dict):
raise ValueError(f"YAML root must be a mapping, got {type(yaml_data).__name__}")
config_data.update(yaml_data)
# 3. Override with any CLI args the user explicitly passed
# for key, val in args_dict.items():
# if key == "config":
# continue
# if key in explicit_args:
# config_data[key] = val
print(f"YAML: {config_data}")
print(f"CLI: {args_dict}")
dict_config = EvaluationConfig.non_default_update(args_dict, config_data, explicit_args)
# 4. Instantiate the Pydantic model (no further validation here)
return cls(**dict_config)
class TrackExplicitAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
# Create a set on the namespace to track explicitly set arguments if it doesn't exist
if not hasattr(namespace, "_explicit_args"):
setattr(namespace, "_explicit_args", set())
# Record that this argument was explicitly provided
namespace._explicit_args.add(self.dest)
setattr(namespace, self.dest, values)
class TrackExplicitStoreTrue(argparse.Action):
def __init__(
self, option_strings, dest, nargs=0, const=True, default=False, **kwargs
):
# Ensure that nargs is 0, as required by store_true actions.
if nargs != 0:
raise ValueError("nargs for store_true actions must be 0")
super().__init__(
option_strings, dest, nargs=0, const=const, default=default, **kwargs
)
def __call__(self, parser, namespace, values, option_string=None):
# Initialize or update the set of explicitly provided arguments.
if not hasattr(namespace, "_explicit_args"):
setattr(namespace, "_explicit_args", set())
namespace._explicit_args.add(self.dest)
setattr(namespace, self.dest, self.const)
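A minimal end-to-end sketch of how the pieces interact (hypothetical mini-parser, not the harness's real setup_parser): arguments registered with the tracking actions record themselves in _explicit_args, which is exactly the set from_cli uses to decide which CLI values may override the YAML.

# Illustration only: only flags the user actually typed end up in _explicit_args.
import argparse

from lm_eval.api.eval_config import TrackExplicitAction, TrackExplicitStoreTrue

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=1, action=TrackExplicitAction)
parser.add_argument("--log_samples", action=TrackExplicitStoreTrue)

ns = parser.parse_args(["--batch_size", "4"])
# ns.batch_size == 4 and ns.log_samples == False
# ns._explicit_args == {"batch_size"}: from_cli keeps this CLI value even if the
# YAML config also sets batch_size, while log_samples (an untouched default)
# would be overwritten by a YAML value if one is present.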
@@ -34,6 +34,7 @@ from lm_eval.utils import (
setup_logging,
simple_parse_args_string,
)
from lm_eval.api.eval_config import EvaluationConfig
if TYPE_CHECKING:
@@ -45,37 +46,11 @@ eval_logger = logging.getLogger(__name__)
@positional_deprecated
def simple_evaluate(
model,
model_args: Optional[Union[str, dict]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None,
batch_size: Optional[Union[int, str]] = None,
max_batch_size: Optional[int] = None,
device: Optional[str] = None,
use_cache: Optional[str] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
delete_requests_cache: bool = False,
limit: Optional[Union[int, float]] = None,
samples: Optional[dict] = None,
config: "EvaluationConfig",
# TODO: bootstrap_iters is not passed from cli_evaluate
bootstrap_iters: int = 100000,
check_integrity: bool = False,
write_out: bool = False,
log_samples: bool = True,
evaluation_tracker: Optional[EvaluationTracker] = None,
system_instruction: Optional[str] = None,
apply_chat_template: Union[bool, str] = False,
fewshot_as_multiturn: bool = False,
gen_kwargs: Union[str, dict, None] = None,
task_manager: Optional[TaskManager] = None,
verbosity=None,
predict_only: bool = False,
random_seed: int = 0,
numpy_random_seed: int = 1234,
torch_random_seed: int = 1234,
fewshot_random_seed: int = 1234,
confirm_run_unsafe_code: bool = False,
metadata: Optional[dict] = None,
):
"""Instantiate and evaluate a model on a list of tasks.
@@ -144,106 +119,106 @@ def simple_evaluate(
return
Dictionary of results
"""
if verbosity is not None:
setup_logging(verbosity=verbosity)
if config.verbosity is not None:
setup_logging(verbosity=config.verbosity)
start_date = time.time()
if limit is not None and samples is not None:
if config.limit is not None and config.samples is not None:
raise ValueError(
"Either 'limit' or 'samples' must be None, but both are not None."
)
if isinstance(model_args, str) and (
"instruct" in model_args and not apply_chat_template
if isinstance(config.model_args, str) and (
"instruct" in config.model_args and not config.apply_chat_template
):
eval_logger.warning(
"Instruct model detected, but chat template not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
)
if delete_requests_cache:
if config.request_caching_args.get("delete_requests_cache", False):
eval_logger.info("Deleting requests cache...")
delete_cache()
seed_message = []
if random_seed is not None:
if config.seed[0] is not None:
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
seed_message.append(f"Setting random seed to {random_seed}")
random.seed(random_seed)
seed_message.append(f"Setting random seed to {config.seed[0]}")
random.seed(config.seed[0])
if numpy_random_seed is not None:
seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
np.random.seed(numpy_random_seed)
if config.seed[1] is not None:
seed_message.append(f"Setting numpy seed to {config.seed[1]}")
np.random.seed(config.seed[1])
if torch_random_seed is not None:
seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
torch.manual_seed(torch_random_seed)
if config.seed[2] is not None:
seed_message.append(f"Setting torch manual seed to {config.seed[2]}")
torch.manual_seed(config.seed[2])
if fewshot_random_seed is not None:
seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
if config.seed[3] is not None:
seed_message.append(f"Setting fewshot manual seed to {config.seed[3]}")
if seed_message:
eval_logger.info(" | ".join(seed_message))
if tasks is None:
tasks = []
if len(tasks) == 0:
if config.tasks is None:
config.tasks = []
if len(config.tasks) == 0:
raise ValueError(
"No tasks specified, or no tasks found. Please verify the task names."
)
if gen_kwargs is not None:
if isinstance(gen_kwargs, str):
gen_kwargs = simple_parse_args_string(gen_kwargs)
if config.gen_kwargs is not None:
if isinstance(config.gen_kwargs, str):
config.gen_kwargs = simple_parse_args_string(config.gen_kwargs)
eval_logger.warning(
f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
f"generation_kwargs: {config.gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
"Ensure 'do_sample=True' for non-greedy decoding!"
)
if not gen_kwargs:
gen_kwargs = None
if not config.gen_kwargs:
config.gen_kwargs = None
if isinstance(model, str):
if model_args is None:
if isinstance(config.model, str):
if config.model_args is None:
eval_logger.warning("model_args not specified. Using defaults.")
model_args = ""
config.model_args = ""
if isinstance(model_args, dict):
if isinstance(config.model_args, dict):
eval_logger.info(
f"Initializing {model} model, with arguments: {model_args}"
f"Initializing {config.model} model, with arguments: {config.model_args}"
)
lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
model_args,
lm = lm_eval.api.registry.get_model(config.model).create_from_arg_obj(
config.model_args,
{
"batch_size": batch_size,
"max_batch_size": max_batch_size,
"device": device,
"batch_size": config.batch_size,
"max_batch_size": config.max_batch_size,
"device": config.device,
},
)
else:
eval_logger.info(
f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
f"Initializing {config.model} model, with arguments: {simple_parse_args_string(config.model_args)}"
)
lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
model_args,
lm = lm_eval.api.registry.get_model(config.model).create_from_arg_string(
config.model_args,
{
"batch_size": batch_size,
"max_batch_size": max_batch_size,
"device": device,
"batch_size": config.batch_size,
"max_batch_size": config.max_batch_size,
"device": config.device,
},
)
else:
if not isinstance(model, lm_eval.api.model.LM):
if not isinstance(config.model, lm_eval.api.model.LM):
raise TypeError(
f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
f"The value of `model` passed to simple_evaluate() was of type {type(config.model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
)
eval_logger.info("Using pre-initialized model")
lm = model
lm = config.model
if use_cache is not None:
eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
if config.use_cache is not None:
eval_logger.info(f"Using cache at {config.use_cache + '_rank' + str(lm.rank) + '.db'}")
lm = lm_eval.api.model.CachingLM(
lm,
use_cache
config.use_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
+ "_rank"
@@ -252,17 +227,10 @@ def simple_evaluate(
)
if task_manager is None:
metadata = (
simple_parse_args_string(model_args)
if isinstance(model_args, str)
else model_args
if isinstance(model_args, dict)
else {}
) | (metadata or {})
task_manager = TaskManager(metadata=metadata)
task_manager = TaskManager(metadata=config.metadata)
task_dict = get_task_dict(
tasks,
config.tasks,
task_manager,
)
@@ -279,15 +247,15 @@ def simple_evaluate(
else:
if task_obj.get_config("output_type") == "generate_until":
if gen_kwargs is not None:
if config.gen_kwargs is not None:
task_obj.set_config(
key="generation_kwargs", value=gen_kwargs, update=True
key="generation_kwargs", value=config.gen_kwargs, update=True
)
eval_logger.info(
f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
)
if predict_only:
if config.predict_only:
eval_logger.info(
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
)
@@ -296,7 +264,7 @@ def simple_evaluate(
# override tasks' fewshot values to the provided num_fewshot arg value
# except if tasks have it set to 0 manually in their configs--then we should never overwrite that
if num_fewshot is not None:
if config.num_fewshot is not None:
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
@@ -305,7 +273,7 @@ def simple_evaluate(
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj.set_config(key="num_fewshot", value=num_fewshot)
task_obj.set_config(key="num_fewshot", value=config.num_fewshot)
else:
# if num_fewshot not provided, and the task does not define a default one, default to 0
if (
@@ -313,7 +281,7 @@ def simple_evaluate(
) is None:
task_obj.set_config(key="num_fewshot", value=0)
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
task_obj.set_fewshot_seed(seed=fewshot_random_seed)
task_obj.set_fewshot_seed(seed=config.seed[3])
adjusted_task_dict[task_name] = task_obj
@@ -321,51 +289,51 @@
task_dict = _adjust_config(task_dict)
if check_integrity:
run_task_tests(task_list=tasks)
if config.check_integrity:
run_task_tests(task_list=config.tasks)
if evaluation_tracker is not None:
evaluation_tracker.general_config_tracker.log_experiment_args(
model_source=model,
model_args=model_args,
system_instruction=system_instruction,
chat_template=lm.chat_template(apply_chat_template)
if apply_chat_template
model_source=config.model,
model_args=config.model_args,
system_instruction=config.system_instruction,
chat_template=lm.chat_template(config.apply_chat_template)
if config.apply_chat_template
else None,
fewshot_as_multiturn=fewshot_as_multiturn,
fewshot_as_multiturn=config.fewshot_as_multiturn,
)
results = evaluate(
lm=lm,
task_dict=task_dict,
limit=limit,
samples=samples,
cache_requests=cache_requests,
rewrite_requests_cache=rewrite_requests_cache,
limit=config.limit,
samples=config.samples,
cache_requests=config.cache_requests,
rewrite_requests_cache=config.request_caching_args.get("rewrite_requests_cache", False),
bootstrap_iters=bootstrap_iters,
write_out=write_out,
log_samples=True if predict_only else log_samples,
system_instruction=system_instruction,
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
verbosity=verbosity,
confirm_run_unsafe_code=confirm_run_unsafe_code,
write_out=config.write_out,
log_samples=True if config.predict_only else config.log_samples,
system_instruction=config.system_instruction,
apply_chat_template=config.apply_chat_template,
fewshot_as_multiturn=config.fewshot_as_multiturn,
verbosity=config.verbosity,
confirm_run_unsafe_code=config.confirm_run_unsafe_code,
)
if verbosity is not None:
setup_logging(verbosity=verbosity)
if config.verbosity is not None:
setup_logging(verbosity=config.verbosity)
if lm.rank == 0:
if isinstance(model, str):
model_name = model
elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
model_name = model.config._name_or_path
if isinstance(config.model, str):
model_name = config.model
elif hasattr(config.model, "config") and hasattr(config.model.config, "_name_or_path"):
model_name = config.model.config._name_or_path
else:
model_name = type(model).__name__
model_name = type(config.model).__name__
# add info about the model and few shot config
results["config"] = {
"model": model_name,
"model_args": model_args,
"model_args": config.model_args,
}
# add more detailed model info if available
if isinstance(lm, lm_eval.models.huggingface.HFLM):
@@ -373,19 +341,19 @@ def simple_evaluate(
# add info about execution
results["config"].update(
{
"batch_size": batch_size,
"batch_size": config.batch_size,
"batch_sizes": (
list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
),
"device": device,
"use_cache": use_cache,
"limit": limit,
"device": config.device,
"use_cache": config.use_cache,
"limit": config.limit,
"bootstrap_iters": bootstrap_iters,
"gen_kwargs": gen_kwargs,
"random_seed": random_seed,
"numpy_seed": numpy_random_seed,
"torch_seed": torch_random_seed,
"fewshot_seed": fewshot_random_seed,
"gen_kwargs": config.gen_kwargs,
"random_seed": config.seed[0],
"numpy_seed": config.seed[1],
"torch_seed": config.seed[2],
"fewshot_seed": config.seed[3],
}
)
results["git_hash"] = get_git_commit_hash()
@@ -755,11 +723,11 @@ def evaluate(
return None
def request_caching_arg_to_dict(cache_requests: str) -> dict:
request_caching_args = {
"cache_requests": cache_requests in {"true", "refresh"},
"rewrite_requests_cache": cache_requests == "refresh",
"delete_requests_cache": cache_requests == "delete",
}
# def request_caching_arg_to_dict(cache_requests: str) -> dict:
# request_caching_args = {
# "cache_requests": cache_requests in {"true", "refresh"},
# "rewrite_requests_cache": cache_requests == "refresh",
# "delete_requests_cache": cache_requests == "delete",
# }
return request_caching_args
# return request_caching_args
@@ -148,6 +148,16 @@ def simple_parse_args_string(args_string: Optional[str]) -> dict:
return args_dict
def request_caching_arg_to_dict(cache_requests: str) -> dict:
request_caching_args = {
"cache_requests": cache_requests in {"true", "refresh"},
"rewrite_requests_cache": cache_requests == "refresh",
"delete_requests_cache": cache_requests == "delete",
}
return request_caching_args
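The helper now lives in lm_eval.utils (and __main__ imports it from there rather than from the evaluator); its mapping is easiest to read from the possible --cache_requests values.

# The dict returned for each --cache_requests value (follows directly from the code above).
request_caching_arg_to_dict("true")
# -> {"cache_requests": True, "rewrite_requests_cache": False, "delete_requests_cache": False}
request_caching_arg_to_dict("refresh")
# -> {"cache_requests": True, "rewrite_requests_cache": True, "delete_requests_cache": False}
request_caching_arg_to_dict("delete")
# -> {"cache_requests": False, "rewrite_requests_cache": False, "delete_requests_cache": True}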
def join_iters(iters):
for iter in iters:
yield from iter
@@ -543,6 +553,7 @@ def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
return islice(raw_iterator, rank, limit, world_size)
# TODO: why is this metric-calculation function in eval utils?
def weighted_f1_score(items):
from sklearn.metrics import f1_score
@@ -551,96 +562,3 @@ def weighted_f1_score(items):
preds = unzipped_list[1]
fscore = f1_score(golds, preds, average="weighted")
return fscore
def parse_namespace(namespace: argparse.Namespace) -> Dict[str, Any]:
"""
Convert an argparse Namespace object to a dictionary.
Handles all argument types including boolean flags, lists, and None values.
Only includes arguments that were explicitly set (ignores defaults unless used).
Args:
namespace: The argparse.Namespace object to convert
Returns:
Dictionary containing all the namespace arguments and their values
"""
config = {key: value for key, value in vars(namespace).items()}
for key in config:
# TODO: pass this list as param
if key in [
"wandb_args",
"wandb_config_args",
"hf_hub_log_args",
"metadata",
"model_args",
]:
if not isinstance(config[key], dict):
config[key] = simple_parse_args_string(config[key])
if "model_args" not in config:
config["model_args"] = {}
if "metadata" not in config:
config["metadata"] = {}
non_default_args = []
if hasattr(namespace, "_explicit_args"):
non_default_args = namespace._explicit_args
return config, non_default_args
class TrackExplicitAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
# Create a set on the namespace to track explicitly set arguments if it doesn't exist
if not hasattr(namespace, "_explicit_args"):
setattr(namespace, "_explicit_args", set())
# Record that this argument was explicitly provided
namespace._explicit_args.add(self.dest)
setattr(namespace, self.dest, values)
class TrackExplicitStoreTrue(argparse.Action):
def __init__(
self, option_strings, dest, nargs=0, const=True, default=False, **kwargs
):
# Ensure that nargs is 0, as required by store_true actions.
if nargs != 0:
raise ValueError("nargs for store_true actions must be 0")
super().__init__(
option_strings, dest, nargs=0, const=const, default=default, **kwargs
)
def __call__(self, parser, namespace, values, option_string=None):
# Initialize or update the set of explicitly provided arguments.
if not hasattr(namespace, "_explicit_args"):
setattr(namespace, "_explicit_args", set())
namespace._explicit_args.add(self.dest)
setattr(namespace, self.dest, self.const)
def non_default_update(console_dict, local_dict, non_default_args):
"""
Update local_dict by taking non-default values from console_dict.
Args:
console_dict (dict): The dictionary that potentially provides updates.
local_dict (dict): The dictionary to be updated.
non_default_args (list): The list of args passed by user in console.
Returns:
dict: The updated local_dict.
"""
result_config = {}
for key in set(console_dict.keys()).union(local_dict.keys()):
if key in non_default_args:
result_config[key] = console_dict[key]
else:
if key in local_dict:
result_config[key] = local_dict[key]
else:
result_config[key] = console_dict[key]
return result_config