Commit b5d16d61 authored by artemorloff

enable evaluation from yaml config file

parent d693dcd2
model: vllm
model_args:
  pretrained: Qwen/Qwen2.5-0.5B-Instruct
  dtype: bfloat16
  tensor_parallel_size: 1
tasks: hellaswag,gsm8k
batch_size: 1
trust_remote_code: true
log_samples: true
output_path: ./test
apply_chat_template: true
fewshot_as_multiturn: true
limit: 5
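
With this change, a run like the one described by the YAML file above can be launched through the new `--config`/`-C` flag instead of a long command line. A minimal usage sketch (the config file name is illustrative, not part of the commit):

    lm_eval --config eval_config.yaml
    lm_eval --config eval_config.yaml --limit 10   # options passed explicitly on the CLI still override the YAML values
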
@@ -12,9 +12,13 @@ from lm_eval.evaluator import request_caching_arg_to_dict
 from lm_eval.loggers import EvaluationTracker, WandbLogger
 from lm_eval.tasks import TaskManager
 from lm_eval.utils import (
+    TrackExplicitAction,
+    TrackExplicitStoreTrue,
     handle_non_serializable,
+    load_yaml_config,
     make_table,
-    simple_parse_args_string,
+    non_default_update,
+    parse_namespace,
 )
@@ -83,13 +87,28 @@ def check_argument_types(parser: argparse.ArgumentParser):
 def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
     parser.add_argument(
-        "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
+        "--config",
+        "-C",
+        default=None,
+        type=str,
+        metavar="DIR/file.yaml",
+        action=TrackExplicitAction,
+        help="Path to config with all arguments for `lm-eval`",
+    )
+    parser.add_argument(
+        "--model",
+        "-m",
+        type=str,
+        default="hf",
+        action=TrackExplicitAction,
+        help="Name of model e.g. `hf`",
     )
     parser.add_argument(
         "--tasks",
         "-t",
         default=None,
         type=str,
+        action=TrackExplicitAction,
         metavar="task1,task2",
         help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
     )
@@ -97,6 +116,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--model_args",
         "-a",
         default="",
+        action=TrackExplicitAction,
         type=try_parse_json,
         help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
     )
@@ -105,6 +125,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "-f",
         type=int,
         default=None,
+        action=TrackExplicitAction,
         metavar="N",
         help="Number of examples in few-shot context",
     )
@@ -112,6 +133,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--batch_size",
         "-b",
         type=str,
+        action=TrackExplicitAction,
         default=1,
         metavar="auto|auto:N|N",
         help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
@@ -120,6 +142,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--max_batch_size",
         type=int,
         default=None,
+        action=TrackExplicitAction,
         metavar="N",
         help="Maximal batch size to try with --batch_size auto.",
     )
@@ -127,6 +150,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--device",
         type=str,
         default=None,
+        action=TrackExplicitAction,
         help="Device to use (e.g. cuda, cuda:0, cpu).",
     )
     parser.add_argument(
@@ -134,6 +158,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "-o",
         default=None,
         type=str,
+        action=TrackExplicitAction,
         metavar="DIR|DIR/file.json",
         help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
     )
@@ -142,6 +167,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "-L",
         type=float,
         default=None,
+        action=TrackExplicitAction,
         metavar="N|0<N<1",
         help="Limit the number of examples per task. "
         "If <1, limit is a percentage of the total number of examples.",
@@ -151,6 +177,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "-E",
         default=None,
         type=str,
+        action=TrackExplicitAction,
         metavar="/path/to/json",
         help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
     )
@@ -158,6 +185,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--use_cache",
         "-c",
         type=str,
+        action=TrackExplicitAction,
         default=None,
         metavar="DIR",
         help="A path to a sqlite db file for caching model responses. `None` if not caching.",
@@ -166,25 +194,26 @@ def setup_parser() -> argparse.ArgumentParser:
         "--cache_requests",
         type=str,
         default=None,
+        action=TrackExplicitAction,
         choices=["true", "refresh", "delete"],
         help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
     )
     parser.add_argument(
         "--check_integrity",
-        action="store_true",
+        action=TrackExplicitStoreTrue,
         help="Whether to run the relevant part of the test suite for the tasks.",
     )
     parser.add_argument(
         "--write_out",
         "-w",
-        action="store_true",
+        action=TrackExplicitStoreTrue,
         default=False,
         help="Prints the prompt for the first few documents.",
     )
     parser.add_argument(
         "--log_samples",
         "-s",
-        action="store_true",
+        action=TrackExplicitStoreTrue,
         default=False,
         help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
     )
@@ -192,12 +221,14 @@ def setup_parser() -> argparse.ArgumentParser:
         "--system_instruction",
         type=str,
         default=None,
+        action=TrackExplicitAction,
         help="System instruction to be used in the prompt",
     )
     parser.add_argument(
         "--apply_chat_template",
         type=str,
         nargs="?",
+        action=TrackExplicitAction,
         const=True,
         default=False,
         help=(
@@ -209,13 +240,13 @@ def setup_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument(
         "--fewshot_as_multiturn",
-        action="store_true",
+        action=TrackExplicitStoreTrue,
         default=False,
         help="If True, uses the fewshot as a multi-turn conversation",
     )
     parser.add_argument(
         "--show_config",
-        action="store_true",
+        action=TrackExplicitStoreTrue,
         default=False,
         help="If True, shows the the full config of all tasks at the end of the evaluation.",
     )
@@ -223,6 +254,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--include_path",
         type=str,
         default=None,
+        action=TrackExplicitAction,
         metavar="DIR",
         help="Additional path to include if there are external tasks to include.",
     )
@@ -230,6 +262,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "--gen_kwargs",
         type=try_parse_json,
         default=None,
+        action=TrackExplicitAction,
         help=(
             "Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
             """ e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
@@ -240,6 +273,7 @@ def setup_parser() -> argparse.ArgumentParser:
         "-v",
         type=str.upper,
         default=None,
+        action=TrackExplicitAction,
         metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
         help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
     )
@@ -247,24 +281,27 @@ def setup_parser() -> argparse.ArgumentParser:
         "--wandb_args",
         type=str,
         default="",
+        action=TrackExplicitAction,
         help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
     )
     parser.add_argument(
         "--wandb_config_args",
         type=str,
         default="",
+        action=TrackExplicitAction,
         help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
     )
     parser.add_argument(
         "--hf_hub_log_args",
         type=str,
         default="",
+        action=TrackExplicitAction,
         help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
     )
     parser.add_argument(
         "--predict_only",
         "-x",
-        action="store_true",
+        action=TrackExplicitStoreTrue,
         default=False,
         help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
     )
@@ -272,6 +309,7 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--seed",
         type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
+        action=TrackExplicitAction,
         default=default_seed_string,  # for backward compatibility
         help=(
             "Set seed for python's random, numpy, torch, and fewshot sampling.\n"
@@ -286,18 +324,19 @@ def setup_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument(
         "--trust_remote_code",
-        action="store_true",
+        action=TrackExplicitStoreTrue,
         help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
     )
     parser.add_argument(
         "--confirm_run_unsafe_code",
-        action="store_true",
+        action=TrackExplicitStoreTrue,
        help="Confirm that you understand the risks of running unsafe code for tasks that require it",
     )
     parser.add_argument(
         "--metadata",
         type=json.loads,
         default=None,
+        action=TrackExplicitAction,
         help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
     )
     return parser
@@ -314,96 +353,96 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         parser = setup_parser()
         args = parse_eval_args(parser)

+    # get namespace from console, including config arg
+    config, non_default_args = parse_namespace(args)
+    # if config is passed, load it
+    if config.get("config", False):
+        local_config = load_yaml_config(yaml_path=config["config"])
+        config = non_default_update(config, local_config, non_default_args)
+
     if args.wandb_args:
-        wandb_args_dict = simple_parse_args_string(args.wandb_args)
-        wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
-        wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
+        wandb_logger = WandbLogger(config["wandb_args"], config["wandb_config_args"])

-    utils.setup_logging(args.verbosity)
+    utils.setup_logging(config["verbosity"])
+    # utils.setup_logging(args.verbosity)
     eval_logger = logging.getLogger(__name__)

     os.environ["TOKENIZERS_PARALLELISM"] = "false"

     # update the evaluation tracker args with the output path and the HF token
-    if args.output_path:
-        args.hf_hub_log_args += f",output_path={args.output_path}"
+    if config["output_path"]:
+        config.setdefault("hf_hub_log_args", {})["output_path"] = config["output_path"]
     if os.environ.get("HF_TOKEN", None):
-        args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
+        config.setdefault("hf_hub_log_args", {})["token"] = os.environ.get("HF_TOKEN")

-    evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
+    evaluation_tracker_args = config["hf_hub_log_args"]
     evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)

-    if args.predict_only:
-        args.log_samples = True
-    if (args.log_samples or args.predict_only) and not args.output_path:
+    if config["predict_only"]:
+        config["log_samples"] = True
+    if (config["log_samples"] or config["predict_only"]) and not config["output_path"]:
         raise ValueError(
             "Specify --output_path if providing --log_samples or --predict_only"
         )

-    if args.fewshot_as_multiturn and args.apply_chat_template is False:
+    if config["fewshot_as_multiturn"] and config["apply_chat_template"] is False:
         raise ValueError(
             "When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
         )

-    if args.include_path is not None:
-        eval_logger.info(f"Including path: {args.include_path}")
+    if config["include_path"] is not None:
+        eval_logger.info(f"Including path: {config['include_path']}")
-    metadata = (
-        simple_parse_args_string(args.model_args)
-        if isinstance(args.model_args, str)
-        else args.model_args
-        if isinstance(args.model_args, dict)
-        else {}
-    ) | (
-        args.metadata
-        if isinstance(args.metadata, dict)
-        else simple_parse_args_string(args.metadata)
-    )
+    metadata = (config["model_args"]) | (config["metadata"])

-    task_manager = TaskManager(include_path=args.include_path, metadata=metadata)
+    task_manager = TaskManager(include_path=config["include_path"], metadata=metadata)

-    if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
+    if "push_samples_to_hub" in evaluation_tracker_args and not config["log_samples"]:
         eval_logger.warning(
             "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
         )

-    if args.limit:
+    if config["limit"]:
         eval_logger.warning(
             " --limit SHOULD ONLY BE USED FOR TESTING."
             "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
         )
-    if args.samples:
-        assert args.limit is None, (
+    if config["samples"]:
+        assert config["limit"] is None, (
             "If --samples is not None, then --limit must be None."
         )
-        if (samples := Path(args.samples)).is_file():
-            args.samples = json.loads(samples.read_text())
+        if (samples := Path(config["samples"])).is_file():
+            config["samples"] = json.loads(samples.read_text())
         else:
-            args.samples = json.loads(args.samples)
+            config["samples"] = json.loads(config["samples"])

-    if args.tasks is None:
+    if config["tasks"] is None:
         eval_logger.error("Need to specify task to evaluate.")
         sys.exit()
-    elif args.tasks == "list":
+    elif config["tasks"] == "list":
         print(task_manager.list_all_tasks())
         sys.exit()
-    elif args.tasks == "list_groups":
+    elif config["tasks"] == "list_groups":
         print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
         sys.exit()
-    elif args.tasks == "list_tags":
+    elif config["tasks"] == "list_tags":
         print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
         sys.exit()
-    elif args.tasks == "list_subtasks":
+    elif config["tasks"] == "list_subtasks":
         print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
         sys.exit()
     else:
-        if os.path.isdir(args.tasks):
+        if os.path.isdir(config["tasks"]):
             import glob

             task_names = []
-            yaml_path = os.path.join(args.tasks, "*.yaml")
+            yaml_path = os.path.join(config["tasks"], "*.yaml")
             for yaml_file in glob.glob(yaml_path):
                 config = utils.load_yaml_config(yaml_file)
                 task_names.append(config)
         else:
-            task_list = args.tasks.split(",")
+            task_list = config["tasks"].split(",")
             task_names = task_manager.match_tasks(task_list)
             for task in [task for task in task_list if task not in task_names]:
                 if os.path.isfile(task):
@@ -424,7 +463,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
                 )

     # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
-    if args.trust_remote_code:
+    if config["trust_remote_code"]:
         eval_logger.info(
             "Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
         )
@@ -435,7 +474,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

-        args.model_args = args.model_args + ",trust_remote_code=True"
+        config.setdefault("model_args", {})["trust_remote_code"] = True

     (
         eval_logger.info(f"Selected Tasks: {task_names}")
         if eval_logger.getEffectiveLevel() >= logging.INFO
@@ -443,56 +482,56 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     )

     request_caching_args = request_caching_arg_to_dict(
-        cache_requests=args.cache_requests
+        cache_requests=config["cache_requests"]
     )

     results = evaluator.simple_evaluate(
-        model=args.model,
-        model_args=args.model_args,
+        model=config["model"],
+        model_args=config["model_args"],
         tasks=task_names,
-        num_fewshot=args.num_fewshot,
-        batch_size=args.batch_size,
-        max_batch_size=args.max_batch_size,
-        device=args.device,
-        use_cache=args.use_cache,
-        limit=args.limit,
-        samples=args.samples,
-        check_integrity=args.check_integrity,
-        write_out=args.write_out,
-        log_samples=args.log_samples,
+        num_fewshot=config["num_fewshot"],
+        batch_size=config["batch_size"],
+        max_batch_size=config["max_batch_size"],
+        device=config["device"],
+        use_cache=config["use_cache"],
+        limit=config["limit"],
+        samples=config["samples"],
+        check_integrity=config["check_integrity"],
+        write_out=config["write_out"],
+        log_samples=config["log_samples"],
         evaluation_tracker=evaluation_tracker,
-        system_instruction=args.system_instruction,
-        apply_chat_template=args.apply_chat_template,
-        fewshot_as_multiturn=args.fewshot_as_multiturn,
-        gen_kwargs=args.gen_kwargs,
+        system_instruction=config["system_instruction"],
+        apply_chat_template=config["apply_chat_template"],
+        fewshot_as_multiturn=config["fewshot_as_multiturn"],
+        gen_kwargs=config["gen_kwargs"],
         task_manager=task_manager,
-        predict_only=args.predict_only,
-        random_seed=args.seed[0],
-        numpy_random_seed=args.seed[1],
-        torch_random_seed=args.seed[2],
-        fewshot_random_seed=args.seed[3],
-        confirm_run_unsafe_code=args.confirm_run_unsafe_code,
+        predict_only=config["predict_only"],
+        random_seed=config["seed"][0],
+        numpy_random_seed=config["seed"][1],
+        torch_random_seed=config["seed"][2],
+        fewshot_random_seed=config["seed"][3],
+        confirm_run_unsafe_code=config["confirm_run_unsafe_code"],
         metadata=metadata,
         **request_caching_args,
     )

     if results is not None:
-        if args.log_samples:
+        if config["log_samples"]:
            samples = results.pop("samples")

         dumped = json.dumps(
             results, indent=2, default=handle_non_serializable, ensure_ascii=False
         )
-        if args.show_config:
+        if config["show_config"]:
             print(dumped)

         batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

         # Add W&B logging
-        if args.wandb_args:
+        if config["wandb_args"]:
             try:
                 wandb_logger.post_init(results)
                 wandb_logger.log_eval_result()
-                if args.log_samples:
+                if config["log_samples"]:
                     wandb_logger.log_eval_samples(samples)
             except Exception as e:
                 eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
@@ -501,8 +540,8 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
             results=results, samples=samples if args.log_samples else None
         )

-        if args.log_samples:
-            for task_name, config in results["configs"].items():
+        if config["log_samples"]:
+            for task_name, _ in results["configs"].items():
                 evaluation_tracker.save_results_samples(
                     task_name=task_name, samples=samples[task_name]
                 )
@@ -514,14 +553,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
            evaluation_tracker.recreate_metadata_card()

         print(
-            f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
-            f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
+            f"{config['model']} ({config['model_args']}), gen_kwargs: ({config['gen_kwargs']}), limit: {config['limit']}, num_fewshot: {config['num_fewshot']}, "
+            f"batch_size: {config['batch_size']}{f' ({batch_sizes})' if batch_sizes else ''}"
         )
         print(make_table(results))
         if "groups" in results:
             print(make_table(results, "groups"))

-    if args.wandb_args:
+    if config["wandb_args"]:
         # Tear down wandb run once all the logging is done.
         wandb_logger.run.finish()
...
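
The parser changes above all route through two custom argparse actions: whenever the user actually types an option, its destination is recorded in an `_explicit_args` set on the namespace, which later tells the config loader which values must not be overwritten by the YAML file. A self-contained sketch of that tracking mechanism (the toy parser and option names below are illustrative only, not part of the commit):

import argparse


class TrackExplicitAction(argparse.Action):
    """Store the value and remember that the option was given explicitly."""

    def __call__(self, parser, namespace, values, option_string=None):
        if not hasattr(namespace, "_explicit_args"):
            namespace._explicit_args = set()
        namespace._explicit_args.add(self.dest)
        setattr(namespace, self.dest, values)


parser = argparse.ArgumentParser()
parser.add_argument("--model", default="hf", action=TrackExplicitAction)
parser.add_argument("--limit", type=int, default=None, action=TrackExplicitAction)

ns = parser.parse_args(["--limit", "10"])
print(ns.model, ns.limit)                    # -> hf 10
print(getattr(ns, "_explicit_args", set()))  # -> {'limit'}: only what the user passed
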
+import argparse
 import collections
 import fnmatch
 import functools
@@ -11,7 +12,7 @@ import re
 from dataclasses import asdict, is_dataclass
 from itertools import islice
 from pathlib import Path
-from typing import Any, Callable, Generator, List, Optional, Tuple
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

 import numpy as np
 import yaml
@@ -550,3 +551,96 @@ def weighted_f1_score(items):
     preds = unzipped_list[1]
     fscore = f1_score(golds, preds, average="weighted")
     return fscore
+
+
+def parse_namespace(namespace: argparse.Namespace) -> Dict[str, Any]:
+    """
+    Convert an argparse Namespace object to a dictionary.
+
+    Handles all argument types including boolean flags, lists, and None values.
+    Only includes arguments that were explicitly set (ignores defaults unless used).
+
+    Args:
+        namespace: The argparse.Namespace object to convert
+
+    Returns:
+        Dictionary containing all the namespace arguments and their values
+    """
+    config = {key: value for key, value in vars(namespace).items()}
+    for key in config:
+        # TODO: pass this list as param
+        if key in [
+            "wandb_args",
+            "wandb_config_args",
+            "hf_hub_log_args",
+            "metadata",
+            "model_args",
+        ]:
+            if not isinstance(config[key], dict):
+                config[key] = simple_parse_args_string(config[key])
+    if "model_args" not in config:
+        config["model_args"] = {}
+    if "metadata" not in config:
+        config["metadata"] = {}
+    non_default_args = []
+    if hasattr(namespace, "_explicit_args"):
+        non_default_args = namespace._explicit_args
+    return config, non_default_args
+
+
+class TrackExplicitAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        # Create a set on the namespace to track explicitly set arguments if it doesn't exist
+        if not hasattr(namespace, "_explicit_args"):
+            setattr(namespace, "_explicit_args", set())
+        # Record that this argument was explicitly provided
+        namespace._explicit_args.add(self.dest)
+        setattr(namespace, self.dest, values)
+
+
+class TrackExplicitStoreTrue(argparse.Action):
+    def __init__(
+        self, option_strings, dest, nargs=0, const=True, default=False, **kwargs
+    ):
+        # Ensure that nargs is 0, as required by store_true actions.
+        if nargs != 0:
+            raise ValueError("nargs for store_true actions must be 0")
+        super().__init__(
+            option_strings, dest, nargs=0, const=const, default=default, **kwargs
+        )
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        # Initialize or update the set of explicitly provided arguments.
+        if not hasattr(namespace, "_explicit_args"):
+            setattr(namespace, "_explicit_args", set())
+        namespace._explicit_args.add(self.dest)
+        setattr(namespace, self.dest, self.const)
+
+
+def non_default_update(console_dict, local_dict, non_default_args):
+    """
+    Update local_dict by taking non-default values from console_dict.
+
+    Args:
+        console_dict (dict): The dictionary that potentially provides updates.
+        local_dict (dict): The dictionary to be updated.
+        non_default_args (list): The list of args passed by user in console.
+
+    Returns:
+        dict: The updated local_dict.
+    """
+    result_config = {}
+    for key in set(console_dict.keys()).union(local_dict.keys()):
+        if key in non_default_args:
+            result_config[key] = console_dict[key]
+        else:
+            if key in local_dict:
+                result_config[key] = local_dict[key]
+            else:
+                result_config[key] = console_dict[key]
+    return result_config
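
Taken together, parse_namespace and non_default_update implement a simple precedence rule: an option the user explicitly passed on the command line wins, otherwise the YAML value is used when present, and the CLI default is kept only for keys the YAML file does not set. A small worked example of the merge (the values are hypothetical, not taken from the commit):

console = {"model": "hf", "limit": 10, "batch_size": 1}         # CLI namespace turned into a dict
yaml_cfg = {"model": "vllm", "limit": 5, "tasks": "hellaswag"}   # values loaded via --config
explicit = {"limit"}                                             # only --limit was typed by the user

merged = non_default_update(console, yaml_cfg, explicit)
# merged == {"model": "vllm", "limit": 10, "batch_size": 1, "tasks": "hellaswag"}
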