Commit febdcc5b authored by Baber

add subcommands

parent 30fa3c7c
-from typing import Union
-
-import argparse
-
 from lm_eval._cli import CLIParser
 
 
-def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
+def cli_evaluate() -> None:
     """Main CLI entry point with subcommand and legacy support."""
     parser = CLIParser()
-
-    if args is None:
-        # Parse from command line
-        parser.execute()
-    else:
-        # External call with pre-parsed args - use legacy mode
-        parser._handle_legacy_mode(args)
+    args = parser.parse_args()
+    parser.execute(args)
 
 
 if __name__ == "__main__":
     cli_evaluate()
\ No newline at end of file
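Note for external callers: cli_evaluate() no longer accepts a pre-parsed Namespace; arguments are always read from sys.argv and routed through CLIParser. A hedged sketch of programmatic use under the new contract (the lm_eval.__main__ module path is assumed here, it is not shown in this hunk):

import sys

from lm_eval.__main__ import cli_evaluate  # assumed entry-point module for this commit

# Legacy callers passed a Namespace; now the CLI arguments go through argv instead.
sys.argv = ["lm-eval", "run", "--model", "hf", "--tasks", "hellaswag"]
cli_evaluate()  # CLIParser.parse_args() reads sys.argv and dispatches to the `run` subcommand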
@@ -3,17 +3,18 @@ CLI subcommands for the Language Model Evaluation Harness.
 """
 
 from lm_eval._cli.base import SubCommand
-from lm_eval._cli.cache import CacheCommand
-from lm_eval._cli.evaluate import EvaluateCommand
+from lm_eval._cli.cache import Cache
+from lm_eval._cli.cli import CLIParser
 from lm_eval._cli.list import ListCommand
-from lm_eval._cli.parser import CLIParser
+from lm_eval._cli.run import Run
 from lm_eval._cli.validate import ValidateCommand
 
 
 __all__ = [
     "SubCommand",
-    "EvaluateCommand",
+    "Run",
     "ListCommand",
     "ValidateCommand",
-    "CacheCommand",
+    "Cache",
     "CLIParser",
 ]
\ No newline at end of file
 import argparse
-import json
-import logging
 from abc import ABC, abstractmethod
-from typing import Union
-
-
-def try_parse_json(value: str) -> Union[str, dict, None]:
-    if value is None:
-        return None
-    try:
-        return json.loads(value)
-    except json.JSONDecodeError:
-        if "{" in value:
-            raise argparse.ArgumentTypeError(
-                f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
-            )
-        return value
-
-
-def _int_or_none_list_arg_type(
-    min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
-):
-    def parse_value(item):
-        item = item.strip().lower()
-        if item == "none":
-            return None
-        try:
-            return int(item)
-        except ValueError:
-            raise argparse.ArgumentTypeError(f"{item} is not an integer or None")
-
-    items = [parse_value(v) for v in value.split(split_char)]
-    num_items = len(items)
-
-    if num_items == 1:
-        items = items * max_len
-    elif num_items < min_len or num_items > max_len:
-        raise argparse.ArgumentTypeError(
-            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
-        )
-    elif num_items != max_len:
-        logging.warning(
-            f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
-            "Missing values will be filled with defaults."
-        )
-        default_items = [parse_value(v) for v in defaults.split(split_char)]
-        items.extend(default_items[num_items:])
-
-    return items
-
-
 class SubCommand(ABC):
...
@@ -3,7 +3,7 @@ import argparse
 from lm_eval._cli.base import SubCommand
 
 
-class CacheCommand(SubCommand):
+class Cache(SubCommand):
     """Command for cache management."""
 
     def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
@@ -43,28 +43,4 @@ Examples:
     def execute(self, args: argparse.Namespace) -> None:
         """Execute the cache command."""
-        import os
-
-        if args.action == "clear":
-            if args.cache_path:
-                if os.path.exists(args.cache_path):
-                    if os.path.isdir(args.cache_path):
-                        import shutil
-
-                        shutil.rmtree(args.cache_path)
-                    else:
-                        os.remove(args.cache_path)
-                    print(f"✅ Cache cleared: {args.cache_path}")
-                else:
-                    print(f"❌ Cache path not found: {args.cache_path}")
-            else:
-                print("❌ Please specify --cache_path")
-        elif args.action == "info":
-            if args.cache_path and os.path.exists(args.cache_path):
-                import os
-
-                size = os.path.getsize(args.cache_path)
-                print(f"Cache: {args.cache_path}")
-                print(f"Size: {size} bytes")
-            else:
-                print("❌ Cache path not found or not specified")
+        raise NotImplementedError
+import argparse
+import sys
+
+from lm_eval._cli.cache import Cache
+from lm_eval._cli.run import Run
+from lm_eval._cli.list import ListCommand
+from lm_eval._cli.validate import ValidateCommand
+
+
+class CLIParser:
+    """Main CLI parser class that manages all subcommands."""
+
+    def __init__(self):
+        self._parser = argparse.ArgumentParser(
+            prog="lm-eval",
+            description="Language Model Evaluation Harness",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._parser.set_defaults(func=lambda args: self._parser.print_help())
+        self._subparsers = self._parser.add_subparsers(
+            dest="command", help="Available commands", metavar="COMMAND"
+        )
+        Run.create(self._subparsers)
+        ListCommand.create(self._subparsers)
+        ValidateCommand.create(self._subparsers)
+        Cache.create(self._subparsers)
+
+    def parse_args(self) -> argparse.Namespace:
+        """Parse arguments using the main parser."""
+        if len(sys.argv) > 2 and sys.argv[1] not in self._subparsers.choices:
+            # Arguments provided but no valid subcommand - insert 'run'
+            sys.argv.insert(1, "run")
+        return self._parser.parse_args()
+
+    def execute(self, args: argparse.Namespace) -> None:
+        """Main execution method that handles subcommands and legacy support."""
+        # Handle legacy task listing
+        if hasattr(args, "tasks") and args.tasks in [
+            "list",
+            "list_groups",
+            "list_subtasks",
+            "list_tags",
+        ]:
+            print(
+                f"'--tasks {args.tasks}' is no longer supported.\n"
+                f"Use the 'list' command instead:\n",
+                file=sys.stderr,
+            )
+            # Show list command help
+            list_parser = self._subparsers.choices["list"]
+            list_parser.print_help()
+            sys.exit(1)
+        args.func(args)
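The parse_args() override above is what keeps pre-subcommand invocations working: when the first token is not a registered subcommand, "run" is spliced into sys.argv before parsing. A standalone approximation of that routing (plain list manipulation; names are illustrative):

def route_legacy_argv(argv: list[str], known_commands: set[str]) -> list[str]:
    # Mirrors CLIParser.parse_args(): `lm-eval --model hf --tasks hellaswag`
    # is rewritten to `lm-eval run --model hf --tasks hellaswag`.
    if len(argv) > 2 and argv[1] not in known_commands:
        argv = [argv[0], "run", *argv[1:]]
    return argv

print(route_legacy_argv(
    ["lm-eval", "--model", "hf", "--tasks", "hellaswag"],
    known_commands={"run", "list", "validate", "cache"},
))
# -> ['lm-eval', 'run', '--model', 'hf', '--tasks', 'hellaswag']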
@@ -19,14 +19,11 @@ Examples:
     lm-eval list groups      # List task groups only
     lm-eval list subtasks    # List subtasks only
     lm-eval list tags        # List available tags
+    lm-eval list tasks --include_path /path/to/external/tasks
 """,
             formatter_class=argparse.RawDescriptionHelpFormatter,
         )
 
-        # Add command-specific arguments
         self._add_args(parser)
-
-        # Set the function to execute for this subcommand
         parser.set_defaults(func=self.execute)
 
     def _add_args(self, parser: argparse.ArgumentParser) -> None:
...
-import argparse
-import sys
-from typing import Dict, Type
-
-from lm_eval._cli.base import SubCommand
-from lm_eval._cli.cache import CacheCommand
-from lm_eval._cli.evaluate import EvaluateCommand
-from lm_eval._cli.list import ListCommand
-from lm_eval._cli.validate import ValidateCommand
-
-
-def check_argument_types(parser: argparse.ArgumentParser):
-    """
-    Check to make sure all CLI args are typed, raises error if not
-    """
-    for action in parser._actions:
-        # Skip help, subcommands, and const actions
-        if action.dest in ["help", "command"] or action.const is not None:
-            continue
-        if action.type is None:
-            raise ValueError(f"Argument '{action.dest}' doesn't have a type specified.")
-        else:
-            continue
-
-
-class CLIParser:
-    """Main CLI parser class that manages all subcommands."""
-
-    def __init__(self):
-        self.parser = None
-        self.subparsers = None
-        self.legacy_parser = None
-        self.command_instances: Dict[str, SubCommand] = {}
-
-    def setup_parser(self) -> argparse.ArgumentParser:
-        """Set up the main parser with subcommands."""
-        if self.parser is not None:
-            return self.parser
-
-        self.parser = argparse.ArgumentParser(
-            prog="lm-eval",
-            description="Language Model Evaluation Harness",
-            formatter_class=argparse.RawTextHelpFormatter,
-        )
-
-        # Create subparsers
-        self.subparsers = self.parser.add_subparsers(
-            dest="command", help="Available commands", metavar="COMMAND"
-        )
-
-        # Create and register all command instances
-        self.command_instances = {
-            "evaluate": EvaluateCommand.create(self.subparsers),
-            "list": ListCommand.create(self.subparsers),
-            "validate": ValidateCommand.create(self.subparsers),
-            "cache": CacheCommand.create(self.subparsers),
-        }
-
-        return self.parser
-
-    def setup_legacy_parser(self) -> argparse.ArgumentParser:
-        """Set up legacy parser for backward compatibility."""
-        if self.legacy_parser is not None:
-            return self.legacy_parser
-
-        self.legacy_parser = argparse.ArgumentParser(
-            formatter_class=argparse.RawTextHelpFormatter
-        )
-
-        # For legacy mode, we just need to add the evaluate command's arguments
-        # without the subcommand structure. We'll create a temporary instance.
-        from lm_eval._cli.evaluate import EvaluateCommand as EvalCmd
-
-        # Create a minimal instance just to get the arguments
-        temp_cmd = object.__new__(EvalCmd)
-        temp_cmd._add_args(self.legacy_parser)
-
-        return self.legacy_parser
-
-    def parse_args(self, args=None) -> argparse.Namespace:
-        """Parse arguments using the main parser."""
-        parser = self.setup_parser()
-        check_argument_types(parser)
-        return parser.parse_args(args)
-
-    def parse_legacy_args(self, args=None) -> argparse.Namespace:
-        """Parse arguments using the legacy parser."""
-        parser = self.setup_legacy_parser()
-        check_argument_types(parser)
-        return parser.parse_args(args)
-
-    def should_use_subcommand_mode(self, argv=None) -> bool:
-        """Determine if we should use subcommand mode based on arguments."""
-        if argv is None:
-            argv = sys.argv[1:]
-
-        # If no arguments, show main help
-        if len(argv) == 0:
-            return True
-
-        # Check if first argument is a known subcommand
-        # First ensure parser is set up to populate command_instances
-        if not self.command_instances:
-            self.setup_parser()
-        if len(argv) > 0 and argv[0] in self.command_instances:
-            return True
-
-        return False
-
-    def execute(self, argv=None) -> None:
-        """Main execution method that handles both subcommand and legacy modes."""
-        if self.should_use_subcommand_mode(argv):
-            # Use subcommand mode
-            if argv is None and len(sys.argv) == 1:
-                # No arguments provided, show help
-                self.setup_parser().print_help()
-                sys.exit(1)
-            args = self.parse_args(argv)
-            args.func(args)
-        else:
-            # Use legacy mode for backward compatibility
-            args = self.parse_legacy_args(argv)
-            self._handle_legacy_mode(args)
-
-    def _handle_legacy_mode(self, args: argparse.Namespace) -> None:
-        """Handle legacy CLI mode for backward compatibility."""
-        # Handle legacy task listing
-        if hasattr(args, "tasks") and args.tasks in [
-            "list",
-            "list_groups",
-            "list_subtasks",
-            "list_tags",
-        ]:
-            from lm_eval.tasks import TaskManager
-
-            task_manager = TaskManager(include_path=getattr(args, "include_path", None))
-            if args.tasks == "list":
-                print(task_manager.list_all_tasks())
-            elif args.tasks == "list_groups":
-                print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
-            elif args.tasks == "list_subtasks":
-                print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
-            elif args.tasks == "list_tags":
-                print(
-                    task_manager.list_all_tasks(list_groups=False, list_subtasks=False)
-                )
-            sys.exit(0)
-
-        # Handle legacy evaluation
-        # Use existing instance if available, otherwise create temporary one
-        if "evaluate" in self.command_instances:
-            evaluate_cmd = self.command_instances["evaluate"]
-        else:
-            # For legacy mode, we don't need the subparser registration
-            # Just execute with the existing args
-            from lm_eval._cli.evaluate import EvaluateCommand as EvalCmd
-
-            # Create a minimal instance just for execution
-            evaluate_cmd = object.__new__(EvalCmd)
-
-        evaluate_cmd.execute(args)
-
-    def add_command(self, name: str, command_class: Type[SubCommand]) -> None:
-        """Add a new command to the parser (for extensibility)."""
-        # If parser is already set up, create and register the command instance
-        if self.subparsers is not None:
-            self.command_instances[name] = command_class.create(self.subparsers)
-        else:
-            # Store class for later instantiation
-            if not hasattr(self, "_pending_commands"):
-                self._pending_commands = {}
-            self._pending_commands[name] = command_class
@@ -2,28 +2,31 @@ import argparse
 import json
 import logging
 import os
-import sys
 from functools import partial
-from pathlib import Path
 
-from lm_eval._cli.base import SubCommand, _int_or_none_list_arg_type, try_parse_json
+from lm_eval._cli import SubCommand
+from lm_eval._cli.utils import (
+    _int_or_none_list_arg_type,
+    request_caching_arg_to_dict,
+    try_parse_json,
+)
 
 
-class EvaluateCommand(SubCommand):
+class Run(SubCommand):
     """Command for running language model evaluation."""
 
     def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
         # Create and configure the parser
         super().__init__(*args, **kwargs)
         parser = subparsers.add_parser(
-            "evaluate",
+            "run",
             help="Run language model evaluation",
             description="Evaluate language models on various benchmarks and tasks.",
             epilog="""
 Examples:
-  lm-eval evaluate --model hf --model_args pretrained=gpt2 --tasks hellaswag
-  lm-eval evaluate --config my_config.yaml --tasks arc_easy,arc_challenge
-  lm-eval evaluate --model openai --tasks mmlu --num_fewshot 5
+  lm-eval run --model hf --model_args pretrained=gpt2 --tasks hellaswag
+  lm-eval run --config my_config.yaml --tasks arc_easy,arc_challenge
+  lm-eval run --model openai --tasks mmlu --num_fewshot 5
 """,
             formatter_class=argparse.RawDescriptionHelpFormatter,
         )
@@ -48,7 +51,7 @@ Examples:
             "-m",
             type=str,
             default="hf",
-            help="Name of model e.g. `hf`",
+            help="Name of model. Default 'hf'",
         )
         parser.add_argument(
             "--tasks",
@@ -61,7 +64,7 @@ Examples:
         parser.add_argument(
             "--model_args",
            "-a",
-            default="",
+            default=None,
             type=try_parse_json,
             help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'.""",
         )
@@ -77,9 +80,9 @@ Examples:
             "--batch_size",
             "-b",
             type=str,
-            default=1,
+            default=argparse.SUPPRESS,
             metavar="auto|auto:N|N",
-            help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
+            help="Acceptable values are 'auto', 'auto:N' (recompute batchsize N times with time) or N, where N is an integer. Default 1.",
         )
         parser.add_argument(
             "--max_batch_size",
@@ -92,7 +95,7 @@ Examples:
             "--device",
             type=str,
             default=None,
-            help="Device to use (e.g. cuda, cuda:0, cpu).",
+            help="Device to use (e.g. cuda, cuda:0, cpu). Model defaults. Default None.",
         )
         parser.add_argument(
             "--output_path",
@@ -115,7 +118,7 @@ Examples:
             "--samples",
             "-E",
             default=None,
-            type=str,
+            type=try_parse_json,
             metavar="/path/to/json",
             help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
         )
@@ -129,7 +132,7 @@ Examples:
         )
         parser.add_argument(
             "--cache_requests",
-            type=str,
+            type=request_caching_arg_to_dict,
             default=None,
             choices=["true", "refresh", "delete"],
             help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
@@ -137,20 +140,21 @@ Examples:
         parser.add_argument(
             "--check_integrity",
             action="store_true",
+            default=argparse.SUPPRESS,
             help="Whether to run the relevant part of the test suite for the tasks.",
         )
         parser.add_argument(
             "--write_out",
             "-w",
             action="store_true",
-            default=False,
+            default=argparse.SUPPRESS,
             help="Prints the prompt for the first few documents.",
         )
         parser.add_argument(
             "--log_samples",
             "-s",
             action="store_true",
-            default=False,
+            default=argparse.SUPPRESS,
             help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
         )
         parser.add_argument(
@@ -164,7 +168,7 @@ Examples:
             type=str,
             nargs="?",
             const=True,
-            default=False,
+            default=argparse.SUPPRESS,
             help=(
                 "If True, apply chat template to the prompt. "
                 "Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
@@ -175,13 +179,13 @@ Examples:
         parser.add_argument(
             "--fewshot_as_multiturn",
             action="store_true",
-            default=False,
+            default=argparse.SUPPRESS,
             help="If True, uses the fewshot as a multi-turn conversation",
         )
         parser.add_argument(
             "--show_config",
             action="store_true",
-            default=False,
+            default=argparse.SUPPRESS,
             help="If True, shows the the full config of all tasks at the end of the evaluation.",
         )
         parser.add_argument(
@@ -197,7 +201,7 @@ Examples:
             default=None,
             help=(
                 "Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
-                """ e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
+                """ e.g. '{"do_sample": true, "temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
             ),
         )
         parser.add_argument(
@@ -211,26 +215,26 @@ Examples:
         parser.add_argument(
             "--wandb_args",
             type=str,
-            default="",
+            default=argparse.SUPPRESS,
             help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
         )
         parser.add_argument(
             "--wandb_config_args",
             type=str,
-            default="",
+            default=argparse.SUPPRESS,
             help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3`",
         )
         parser.add_argument(
             "--hf_hub_log_args",
             type=str,
-            default="",
+            default=argparse.SUPPRESS,
             help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
         )
         parser.add_argument(
             "--predict_only",
             "-x",
             action="store_true",
-            default=False,
+            default=argparse.SUPPRESS,
             help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
         )
         default_seed_string = "0,1234,1234,1234"
@@ -252,11 +256,13 @@ Examples:
         parser.add_argument(
             "--trust_remote_code",
             action="store_true",
+            default=argparse.SUPPRESS,
             help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
         )
         parser.add_argument(
             "--confirm_run_unsafe_code",
             action="store_true",
+            default=argparse.SUPPRESS,
             help="Confirm that you understand the risks of running unsafe code for tasks that require it",
         )
         parser.add_argument(
@@ -268,16 +274,13 @@ Examples:
     def execute(self, args: argparse.Namespace) -> None:
         """Execute the evaluation command."""
-        # Import here to avoid circular imports and for faster CLI loading
-        from lm_eval.api.eval_config import EvaluationConfig
+        from lm_eval.config.evaluate_config import EvaluatorConfig
 
-        # Create and validate config (validation now happens in EvaluationConfig)
-        cfg = EvaluationConfig.from_cli(args)
+        # Create and validate config (most validation now happens in EvaluationConfig)
+        cfg = EvaluatorConfig.from_cli(args)
 
-        from lm_eval import evaluator, utils
-        from lm_eval.evaluator import request_caching_arg_to_dict
+        from lm_eval import simple_evaluate, utils
         from lm_eval.loggers import EvaluationTracker, WandbLogger
-        from lm_eval.tasks import TaskManager
         from lm_eval.utils import handle_non_serializable, make_table
 
         # Set up logging
@@ -298,7 +301,7 @@ Examples:
         evaluation_tracker = EvaluationTracker(**cfg.hf_hub_log_args)
 
         # Create task manager (metadata already set up in config validation)
-        task_manager = TaskManager(include_path=cfg.include_path, metadata=cfg.metadata)
+        task_manager = cfg.process_tasks()
 
         # Validation warnings (keep these in CLI as they're logging-specific)
         if "push_samples_to_hub" in cfg.hf_hub_log_args and not cfg.log_samples:
@@ -306,25 +309,13 @@ Examples:
                 "Pushing samples to the Hub requires --log_samples to be set."
             )
 
-        if cfg.limit:
-            eval_logger.warning(
-                "--limit SHOULD ONLY BE USED FOR TESTING. "
-                "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
-            )
-
         # Log task selection (tasks already processed in config)
         if cfg.include_path is not None:
             eval_logger.info(f"Including path: {cfg.include_path}")
 
         eval_logger.info(f"Selected Tasks: {cfg.tasks}")
 
-        # Set up caching
-        request_caching_args = request_caching_arg_to_dict(
-            cache_requests=cfg.cache_requests
-        )
-        cfg.request_caching_args = request_caching_args
-
         # Run evaluation
-        results = evaluator.simple_evaluate(
+        results = simple_evaluate(
             model=cfg.model,
             model_args=cfg.model_args,
             tasks=cfg.tasks,
@@ -333,11 +324,11 @@ Examples:
             max_batch_size=cfg.max_batch_size,
             device=cfg.device,
             use_cache=cfg.use_cache,
-            cache_requests=cfg.request_caching_args.get("cache_requests", False),
-            rewrite_requests_cache=cfg.request_caching_args.get(
+            cache_requests=cfg.cache_requests.get("cache_requests", False),
+            rewrite_requests_cache=cfg.cache_requests.get(
                 "rewrite_requests_cache", False
            ),
-            delete_requests_cache=cfg.request_caching_args.get(
+            delete_requests_cache=cfg.cache_requests.get(
                 "delete_requests_cache", False
             ),
             limit=cfg.limit,
...
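The switch from default=False/"" to default=argparse.SUPPRESS is deliberate: flags the user never passed are simply absent from the parsed Namespace, so EvaluatorConfig.from_cli() lets YAML values and dataclass defaults win the merge instead of being clobbered by argparse defaults. A minimal demonstration of that argparse behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--write_out", action="store_true", default=argparse.SUPPRESS)
parser.add_argument("--log_samples", action="store_true", default=False)

print(vars(parser.parse_args([])))
# -> {'log_samples': False}  (the suppressed flag does not appear at all)
print(vars(parser.parse_args(["--write_out"])))
# -> {'write_out': True, 'log_samples': False}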
+import argparse
+import json
+import logging
+from typing import Optional, Union
+
+
+def try_parse_json(value: Union[dict, str]) -> Union[str, dict, None]:
+    """Try to parse a string as JSON. If it fails, return the original string."""
+    if value is None:
+        return None
+    if isinstance(value, dict):
+        return value
+    try:
+        return json.loads(value)
+    except json.JSONDecodeError:
+        if "{" in value:
+            raise ValueError(
+                f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
+            )
+        return value
+
+
+def _int_or_none_list_arg_type(
+    min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
+) -> list[Union[int, None]]:
+    """Parses a string of integers or 'None' values separated by a specified character into a list.
+    Validates the number of items against specified minimum and maximum lengths and fills missing values with defaults."""
+
+    def parse_value(item):
+        """Parses an individual item, converting it to an integer or `None`."""
+        item = item.strip().lower()
+        if item == "none":
+            return None
+        try:
+            return int(item)
+        except ValueError:
+            raise ValueError(f"{item} is not an integer or None")
+
+    items = [parse_value(v) for v in value.split(split_char)]
+    num_items = len(items)
+
+    if num_items == 1:
+        items = items * max_len
+    elif num_items < min_len or num_items > max_len:
+        raise ValueError(
+            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
+        )
+    elif num_items != max_len:
+        logging.warning(
+            f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
+            "Missing values will be filled with defaults."
+        )
+        default_items = [parse_value(v) for v in defaults.split(split_char)]
+        items.extend(default_items[num_items:])
+
+    return items
+
+
+def request_caching_arg_to_dict(cache_requests: Optional[str]) -> dict[str, bool]:
+    """Convert a request caching argument to a dictionary."""
+    if cache_requests is None:
+        return {}
+    request_caching_args = {
+        "cache_requests": cache_requests in {"true", "refresh"},
+        "rewrite_requests_cache": cache_requests == "refresh",
+        "delete_requests_cache": cache_requests == "delete",
+    }
+    return request_caching_args
+
+
+def check_argument_types(parser: argparse.ArgumentParser):
+    """
+    Check to make sure all CLI args are typed, raises error if not
+    """
+    for action in parser._actions:
+        # Skip help, subcommands, and const actions
+        if action.dest in ["help", "command"] or action.const is not None:
+            continue
+        if action.type is None:
+            raise ValueError(f"Argument '{action.dest}' doesn't have a type specified.")
+        else:
+            continue
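For reference, expected behavior of the new helpers (outputs follow directly from the code above; the import path assumes this commit's lm_eval/_cli/utils.py layout):

from lm_eval._cli.utils import (
    _int_or_none_list_arg_type,
    request_caching_arg_to_dict,
    try_parse_json,
)

print(try_parse_json('{"pretrained": "EleutherAI/pythia-160m"}'))  # parsed into a dict
print(try_parse_json("pretrained=EleutherAI/pythia-160m"))  # returned unchanged as a string
print(_int_or_none_list_arg_type(1, 4, "0,1234,1234,1234", "42"))  # [42, 42, 42, 42]
print(_int_or_none_list_arg_type(1, 4, "0,1234,1234,1234", "0,none"))
# [0, None, 1234, 1234]  (missing values filled from the defaults, with a warning)
print(request_caching_arg_to_dict("refresh"))
# {'cache_requests': True, 'rewrite_requests_cache': True, 'delete_requests_cache': False}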
+from .evaluate_config import EvaluatorConfig
+
+
+__all__ = [
+    "EvaluatorConfig",
+]
 import json
 import logging
+import warnings
 from argparse import Namespace
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import yaml
 
 from lm_eval.utils import simple_parse_args_string
 
+
+if TYPE_CHECKING:
+    from lm_eval.tasks import TaskManager
+
 
 DICT_KEYS = [
     "wandb_args",
     "wandb_config_args",
@@ -20,65 +25,145 @@ DICT_KEYS = [
 @dataclass
-class EvaluationConfig:
-    """
-    Simple config container for holding params.
-    """
-
-    config: Optional[str] = None
-    model: Optional[str] = None
-    model_args: Optional[dict] = None
-    tasks: Optional[str] = None
-    num_fewshot: Optional[int] = None
-    batch_size: Optional[int] = None
-    max_batch_size: Optional[int] = None
-    device: Optional[str] = None
-    output_path: Optional[str] = None
-    limit: Optional[float] = None
-    samples: Optional[str] = None
-    use_cache: Optional[str] = None
-    cache_requests: Optional[str] = None
-    check_integrity: Optional[bool] = None
-    write_out: Optional[bool] = None
-    log_samples: Optional[bool] = None
-    predict_only: Optional[bool] = None
-    system_instruction: Optional[str] = None
-    apply_chat_template: Optional[Union[bool, str]] = None
-    fewshot_as_multiturn: Optional[bool] = None
-    show_config: Optional[bool] = None
-    include_path: Optional[str] = None
-    gen_kwargs: Optional[dict] = None
-    verbosity: Optional[str] = None
-    wandb_args: Optional[dict] = None
-    wandb_config_args: Optional[dict] = None
-    hf_hub_log_args: Optional[dict] = None
-    seed: Optional[list] = None
-    trust_remote_code: Optional[bool] = None
-    confirm_run_unsafe_code: Optional[bool] = None
-    metadata: Optional[dict] = None
-    request_caching_args: Optional[dict] = None
-
-    @staticmethod
-    def _get_defaults() -> Dict[str, Any]:
-        """Get default values for all configuration options."""
-        return {
-            "model": "hf",
-            "model_args": {},
-            "batch_size": 1,
-            "check_integrity": False,
-            "write_out": False,
-            "log_samples": False,
-            "predict_only": False,
-            "fewshot_as_multiturn": False,
-            "show_config": False,
-            "trust_remote_code": False,
-            "confirm_run_unsafe_code": False,
-            "metadata": {},
-            "wandb_args": {},
-            "wandb_config_args": {},
-            "hf_hub_log_args": {},
-            "seed": [0, 1234, 1234, 1234],
-        }
+class EvaluatorConfig:
+    """
+    Configuration container for initializing evaluator or simple_evaluate.
+
+    This dataclass holds all the parameters needed for running evaluations,
+    with sensible defaults and documentation for each field.
+    """
+
+    # Core evaluation parameters
+    config: Optional[str] = field(
+        default=None, metadata={"help": "Path to YAML config file"}
+    )
+    model: str = field(default="hf", metadata={"help": "Name of model e.g. 'hf'"})
+    model_args: dict = field(
+        default_factory=dict, metadata={"help": "Arguments for model initialization"}
+    )
+    tasks: Union[str, list[str]] = field(
+        default_factory=list,
+        metadata={"help": "Comma-separated list of task names to evaluate"},
+    )
+
+    # Few-shot and batching
+    num_fewshot: Optional[int] = field(
+        default=None, metadata={"help": "Number of examples in few-shot context"}
+    )
+    batch_size: int = field(default=1, metadata={"help": "Batch size for evaluation"})
+    max_batch_size: Optional[int] = field(
+        default=None, metadata={"help": "Maximum batch size for auto batching"}
+    )
+
+    # Device
+    device: Optional[str] = field(
+        default=None, metadata={"help": "Device to use (e.g. cuda, cuda:0, cpu)"}
+    )
+
+    # Data sampling and limiting
+    limit: Optional[float] = field(
+        default=None, metadata={"help": "Limit number of examples per task"}
+    )
+    samples: Union[str, dict, None] = field(
+        default=None,
+        metadata={"help": "dict, JSON string or path to JSON file with doc indices"},
+    )
+
+    # Caching
+    use_cache: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to sqlite db file for caching model outputs"},
+    )
+    cache_requests: dict = field(
+        default_factory=dict,
+        metadata={"help": "Cache dataset requests: true/refresh/delete"},
+    )
+
+    # Output and logging flags
+    check_integrity: bool = field(
+        default=False, metadata={"help": "Run test suite for tasks"}
+    )
+    write_out: bool = field(
+        default=False, metadata={"help": "Print prompts for first few documents"}
+    )
+    log_samples: bool = field(
+        default=False, metadata={"help": "Save model outputs and inputs"}
+    )
+    output_path: Optional[str] = field(
+        default=None, metadata={"help": "Dir path where result metrics will be saved"}
+    )
+    predict_only: bool = field(
+        default=False,
+        metadata={
+            "help": "Only save model outputs, don't evaluate metrics. Use with log_samples."
+        },
+    )
+
+    # Chat and instruction handling
+    system_instruction: Optional[str] = field(
+        default=None, metadata={"help": "Custom System instruction to add"}
+    )
+    apply_chat_template: Union[bool, str] = field(
+        default=False, metadata={"help": "Apply chat template to prompt"}
+    )
+    fewshot_as_multiturn: bool = field(
+        default=False,
+        metadata={
+            "help": "Use fewshot as multi-turn conversation. Requires apply_chat_template=True."
+        },
+    )
+
+    # Configuration display
+    show_config: bool = field(
+        default=False, metadata={"help": "Show full config at end of evaluation"}
+    )
+
+    # External tasks and generation
+    include_path: Optional[str] = field(
+        default=None, metadata={"help": "Additional dir path for external tasks"}
+    )
+    gen_kwargs: Optional[dict] = field(
+        default=None, metadata={"help": "Arguments for model generation"}
+    )
+
+    # Logging and verbosity
+    verbosity: Optional[str] = field(
+        default=None, metadata={"help": "Logging verbosity level"}
+    )
+
+    # External integrations
+    wandb_args: dict = field(
+        default_factory=dict, metadata={"help": "Arguments for wandb.init"}
+    )
+    wandb_config_args: dict = field(
+        default_factory=dict, metadata={"help": "Arguments for wandb.config.update"}
+    )
+    hf_hub_log_args: dict = field(
+        default_factory=dict, metadata={"help": "Arguments for HF Hub logging"}
+    )
+
+    # Reproducibility
+    seed: list = field(
+        default_factory=lambda: [0, 1234, 1234, 1234],
+        metadata={"help": "Seeds for random, numpy, torch, fewshot (random)"},
+    )
+
+    # Security and safety
+    trust_remote_code: bool = field(
+        default=False, metadata={"help": "Trust remote code for HF datasets"}
+    )
+    confirm_run_unsafe_code: bool = field(
+        default=False,
+        metadata={
+            "help": "Confirm understanding of unsafe code risks (for code tasks that executes arbitrary Python)"
+        },
+    )
+
+    # Internal metadata
+    metadata: dict = field(
+        default_factory=dict,
+        metadata={"help": "Additional metadata for tasks that require it"},
+    )
 
     @staticmethod
     def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
@@ -89,24 +174,22 @@ class EvaluationConfig:
         return config
 
     @classmethod
-    def from_cli(cls, namespace: Namespace) -> "EvaluationConfig":
+    def from_cli(cls, namespace: Namespace) -> "EvaluatorConfig":
         """
         Build an EvaluationConfig by merging with simple precedence:
         CLI args > YAML config > built-in defaults
         """
         # Start with built-in defaults
-        config = cls._get_defaults()
+        config = asdict(cls())
 
         # Load and merge YAML config if provided
         if hasattr(namespace, "config") and namespace.config:
             config.update(cls._load_yaml_config(namespace.config))
 
-        # Override with CLI args (only non-None values, exclude non-config args)
+        # Override with CLI args (only truthy values, exclude non-config args)
         excluded_args = {"config", "command", "func"}  # argparse internal args
         cli_args = {
-            k: v
-            for k, v in vars(namespace).items()
-            if v is not None and k not in excluded_args
+            k: v for k, v in vars(namespace).items() if v and k not in excluded_args
         }
         config.update(cli_args)
@@ -119,10 +202,30 @@ class EvaluationConfig:
 
         return instance
 
+    @classmethod
+    def from_config(cls, config_path: Union[str, Path]) -> "EvaluatorConfig":
+        """
+        Build an EvaluationConfig from a YAML config file.
+        Merges with built-in defaults and validates.
+        """
+        # Load YAML config
+        yaml_config = cls._load_yaml_config(config_path)
+
+        # Parse string arguments that should be dictionaries
+        yaml_config = cls._parse_dict_args(yaml_config)
+
+        # Create instance and validate
+        instance = cls(**yaml_config)
+        instance.validate_and_preprocess()
+
+        return instance
+
     @staticmethod
-    def _load_yaml_config(config_path: str) -> Dict[str, Any]:
+    def _load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
         """Load and validate YAML config file."""
-        config_file = Path(config_path)
+        config_file = (
+            Path(config_path) if not isinstance(config_path, Path) else config_path
+        )
         if not config_file.is_file():
             raise FileNotFoundError(f"Config file not found: {config_path}")
@@ -146,10 +249,15 @@ class EvaluationConfig:
         self._process_samples()
         self._setup_metadata()
         self._apply_trust_remote_code()
-        self._process_tasks()
 
     def _validate_arguments(self) -> None:
         """Validate configuration arguments and cross-field constraints."""
+        if self.limit:
+            warnings.warn(
+                "--limit SHOULD ONLY BE USED FOR TESTING. "
+                "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
+            )
+
         # predict_only implies log_samples
         if self.predict_only:
             self.log_samples = True
@@ -177,23 +285,36 @@ class EvaluationConfig:
     def _process_samples(self) -> None:
         """Process samples argument - load from file if needed."""
         if self.samples:
-            if (samples_path := Path(self.samples)).is_file():
-                self.samples = json.loads(samples_path.read_text())
-            else:
-                self.samples = json.loads(self.samples)
-
-    def _process_tasks(self, metadata: Union[dict, str]) -> List[str]:
+            if isinstance(self.samples, dict):
+                self.samples = self.samples
+            elif isinstance(self.samples, str):
+                try:
+                    self.samples = json.loads(self.samples)
+                except json.JSONDecodeError:
+                    if (samples_path := Path(self.samples)).is_file():
+                        self.samples = json.loads(samples_path.read_text())
+
+    def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
         """Process and validate tasks, return resolved task names."""
         from lm_eval import utils
         from lm_eval.tasks import TaskManager
 
+        # if metadata manually passed use that:
+        self.metadata = metadata if metadata else self.metadata
+
         # Create task manager with metadata
         task_manager = TaskManager(
-            include_path=self.include_path, metadata=self.metadata
+            include_path=self.include_path,
+            metadata=self.metadata if self.metadata else {},
         )
 
         # self.tasks is a comma-separated string of task names
-        task_list = self.tasks.split(",")
+        if isinstance((task_list := self.tasks), str):
+            task_list = self.tasks.split(",")
+        else:
+            assert isinstance(self.tasks, list), (
+                "Tasks must be a comma delimited string of task names or list[str]."
+            )
         task_names = task_manager.match_tasks(task_list)
 
         # Check for any individual task files in the list
@@ -214,7 +335,7 @@ class EvaluationConfig:
         # Update tasks with resolved names
         self.tasks = task_names
 
-        return task_names
+        return task_manager
 
     def _setup_metadata(self) -> None:
         """Set up metadata by merging model_args and metadata."""
@@ -223,9 +344,7 @@ class EvaluationConfig:
         if self.metadata is None:
             self.metadata = {}
 
-        # Merge model_args and metadata
-        merged_metadata = self.model_args | self.metadata
-        self.metadata = merged_metadata
+        self.metadata = self.model_args | self.metadata
 
     def _apply_trust_remote_code(self) -> None:
         """Apply trust_remote_code setting if enabled."""
...
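The from_cli() merge above reduces to a three-layer override: dataclass defaults, then the YAML file, then any truthy CLI value (which is why the run subcommand now leans on argparse.SUPPRESS). A self-contained sketch of that precedence with illustrative values, not the real config surface:

from dataclasses import asdict, dataclass, field

@dataclass
class _Defaults:
    model: str = "hf"
    batch_size: int = 1
    tasks: list = field(default_factory=list)

defaults = asdict(_Defaults())                        # built-in defaults
yaml_cfg = {"model": "hf", "batch_size": 8}           # loaded from --config my_config.yaml
cli_args = {"tasks": "hellaswag", "batch_size": 16}   # truthy values parsed from argv

merged = {**defaults, **yaml_cfg, **cli_args}         # CLI > YAML > defaults
print(merged)  # {'model': 'hf', 'batch_size': 16, 'tasks': 'hellaswag'}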
@@ -753,13 +753,3 @@ def evaluate(
     else:
         return None
-
-
-def request_caching_arg_to_dict(cache_requests: str) -> dict:
-    request_caching_args = {
-        "cache_requests": cache_requests in {"true", "refresh"},
-        "rewrite_requests_cache": cache_requests == "refresh",
-        "delete_requests_cache": cache_requests == "delete",
-    }
-
-    return request_caching_args