Commit f9d5d3e7 authored by Baber

modularize cli

parent 223b9488
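# ---- lm_eval/_cli/__init__.py ----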
"""
CLI subcommands for the Language Model Evaluation Harness.
"""
from lm_eval._cli.base import SubCommand
from lm_eval._cli.cache import CacheCommand
from lm_eval._cli.evaluate import EvaluateCommand
from lm_eval._cli.list import ListCommand
from lm_eval._cli.parser import CLIParser
from lm_eval._cli.validate import ValidateCommand
__all__ = [
"SubCommand",
"EvaluateCommand",
"ListCommand",
"ValidateCommand",
"CacheCommand",
"CLIParser",
]
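# A minimal usage sketch (hedged): the console-script entry point is not part of
# this module, but CLIParser exposes an execute() method that dispatches either
# to a subcommand or to the legacy flag-style interface:
#
#     from lm_eval._cli import CLIParser
#
#     CLIParser().execute()  # e.g. called from a __main__ module or console script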
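# ---- lm_eval/_cli/base.py ----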
import argparse
import json
import logging
from abc import ABC, abstractmethod
from typing import Union
def try_parse_json(value: str) -> Union[str, dict, None]:
if value is None:
return None
try:
return json.loads(value)
except json.JSONDecodeError:
if "{" in value:
raise argparse.ArgumentTypeError(
f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
)
return value
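# Illustrative behavior of try_parse_json (a hedged sketch of the branches above,
# not additional logic):
#
#     try_parse_json('{"pretrained": "gpt2"}')         # -> {"pretrained": "gpt2"}
#     try_parse_json("pretrained=gpt2,dtype=float32")  # -> returned unchanged as a str
#     try_parse_json('{"pretrained": gpt2}')           # -> raises ArgumentTypeError (malformed JSON)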
def _int_or_none_list_arg_type(
min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
):
def parse_value(item):
item = item.strip().lower()
if item == "none":
return None
try:
return int(item)
except ValueError:
raise argparse.ArgumentTypeError(f"{item} is not an integer or None")
items = [parse_value(v) for v in value.split(split_char)]
num_items = len(items)
if num_items == 1:
items = items * max_len
elif num_items < min_len or num_items > max_len:
raise argparse.ArgumentTypeError(
f"Argument requires {max_len} integers or None, separated by '{split_char}'"
)
elif num_items != max_len:
logging.warning(
f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
"Missing values will be filled with defaults."
)
default_items = [parse_value(v) for v in defaults.split(split_char)]
items.extend(default_items[num_items:])
return items
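# Illustrative behavior for the --seed option, which binds this helper via
# functools.partial(min_len=3, max_len=4, defaults="0,1234,1234,1234") in the
# evaluate command (a hedged sketch derived from the logic above):
#
#     "42"         -> [42, 42, 42, 42]    # a single value is broadcast to max_len
#     "0,None,8"   -> [0, None, 8, 1234]  # missing trailing values filled from defaults
#     "0,1,2,3,4"  -> ArgumentTypeError   # more than max_len values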
class SubCommand(ABC):
"""Base class for all subcommands."""
def __init__(self, *args, **kwargs):
pass
@classmethod
def create(cls, subparsers: argparse._SubParsersAction):
"""Factory method to create and register a command instance."""
return cls(subparsers)
@abstractmethod
def _add_args(self, parser: argparse.ArgumentParser) -> None:
"""Add arguments specific to this subcommand."""
pass
@abstractmethod
def execute(self, args: argparse.Namespace) -> None:
"""Execute the subcommand with the given arguments."""
pass
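# ---- lm_eval/_cli/cache.py ----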
import argparse
from lm_eval._cli.base import SubCommand
class CacheCommand(SubCommand):
"""Command for cache management."""
def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
# Create and configure the parser
super().__init__(*args, **kwargs)
parser = subparsers.add_parser(
"cache",
help="Manage evaluation cache",
description="Manage evaluation cache files and directories.",
epilog="""
Examples:
lm-eval cache clear --cache_path ./cache.db # Clear cache file
lm-eval cache info --cache_path ./cache.db # Show cache info
lm-eval cache clear --cache_path ./cache_dir/ # Clear cache directory
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# Add command-specific arguments
self._add_args(parser)
# Set the function to execute for this subcommand
parser.set_defaults(func=self.execute)
def _add_args(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"action",
choices=["clear", "info"],
help="Action to perform: clear or info",
)
parser.add_argument(
"--cache_path",
type=str,
default=None,
help="Path to cache directory or file",
)
def execute(self, args: argparse.Namespace) -> None:
"""Execute the cache command."""
import os
if args.action == "clear":
if args.cache_path:
if os.path.exists(args.cache_path):
if os.path.isdir(args.cache_path):
import shutil
shutil.rmtree(args.cache_path)
else:
os.remove(args.cache_path)
print(f"✅ Cache cleared: {args.cache_path}")
else:
print(f"❌ Cache path not found: {args.cache_path}")
else:
print("❌ Please specify --cache_path")
elif args.action == "info":
if args.cache_path and os.path.exists(args.cache_path):
size = os.path.getsize(args.cache_path)
print(f"Cache: {args.cache_path}")
print(f"Size: {size} bytes")
else:
print("❌ Cache path not found or not specified")
import argparse
import json
import logging
import os
from functools import partial
from lm_eval._cli.base import SubCommand, _int_or_none_list_arg_type, try_parse_json
class EvaluateCommand(SubCommand):
"""Command for running language model evaluation."""
def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
# Create and configure the parser
super().__init__(*args, **kwargs)
parser = subparsers.add_parser(
"evaluate",
help="Run language model evaluation",
description="Evaluate language models on various benchmarks and tasks.",
epilog="""
Examples:
lm-eval evaluate --model hf --model_args pretrained=gpt2 --tasks hellaswag
lm-eval evaluate --config my_config.yaml --tasks arc_easy,arc_challenge
lm-eval evaluate --model openai --tasks mmlu --num_fewshot 5
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# Add command-specific arguments
self._add_args(parser)
# Set the function to execute for this subcommand
parser.set_defaults(func=self.execute)
def _add_args(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--config",
"-C",
default=None,
type=str,
metavar="DIR/file.yaml",
help="Path to config with all arguments for `lm-eval`",
)
parser.add_argument(
"--model",
"-m",
type=str,
default="hf",
help="Name of model e.g. `hf`",
)
parser.add_argument(
"--tasks",
"-t",
default=None,
type=str,
metavar="task1,task2",
help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
)
parser.add_argument(
"--model_args",
"-a",
default="",
type=try_parse_json,
help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'.""",
)
parser.add_argument(
"--num_fewshot",
"-f",
type=int,
default=None,
metavar="N",
help="Number of examples in few-shot context",
)
parser.add_argument(
"--batch_size",
"-b",
type=str,
default=1,
metavar="auto|auto:N|N",
help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=None,
metavar="N",
help="Maximal batch size to try with --batch_size auto.",
)
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to use (e.g. cuda, cuda:0, cpu).",
)
parser.add_argument(
"--output_path",
"-o",
default=None,
type=str,
metavar="DIR|DIR/file.json",
help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
)
parser.add_argument(
"--limit",
"-L",
type=float,
default=None,
metavar="N|0<N<1",
help="Limit the number of examples per task. "
"If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument(
"--samples",
"-E",
default=None,
type=str,
metavar="/path/to/json",
help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
)
parser.add_argument(
"--use_cache",
"-c",
type=str,
default=None,
metavar="DIR",
help="A path to a sqlite db file for caching model responses. `None` if not caching.",
)
parser.add_argument(
"--cache_requests",
type=str,
default=None,
choices=["true", "refresh", "delete"],
help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
)
parser.add_argument(
"--check_integrity",
action="store_true",
help="Whether to run the relevant part of the test suite for the tasks.",
)
parser.add_argument(
"--write_out",
"-w",
action="store_true",
default=False,
help="Prints the prompt for the first few documents.",
)
parser.add_argument(
"--log_samples",
"-s",
action="store_true",
default=False,
help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
)
parser.add_argument(
"--system_instruction",
type=str,
default=None,
help="System instruction to be used in the prompt",
)
parser.add_argument(
"--apply_chat_template",
type=str,
nargs="?",
const=True,
default=False,
help=(
"If True, apply chat template to the prompt. "
"Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
"To apply a specific template from the available list of templates, provide the template name as an argument. "
"E.g. `--apply_chat_template template_name`"
),
)
parser.add_argument(
"--fewshot_as_multiturn",
action="store_true",
default=False,
help="If True, uses the fewshot as a multi-turn conversation",
)
parser.add_argument(
"--show_config",
action="store_true",
default=False,
help="If True, shows the the full config of all tasks at the end of the evaluation.",
)
parser.add_argument(
"--include_path",
type=str,
default=None,
metavar="DIR",
help="Additional path to include if there are external tasks to include.",
)
parser.add_argument(
"--gen_kwargs",
type=try_parse_json,
default=None,
help=(
"Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
""" e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
),
)
parser.add_argument(
"--verbosity",
"-v",
type=str.upper,
default=None,
metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
)
parser.add_argument(
"--wandb_args",
type=str,
default="",
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
)
parser.add_argument(
"--wandb_config_args",
type=str,
default="",
help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3`",
)
parser.add_argument(
"--hf_hub_log_args",
type=str,
default="",
help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
)
parser.add_argument(
"--predict_only",
"-x",
action="store_true",
default=False,
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
)
default_seed_string = "0,1234,1234,1234"
parser.add_argument(
"--seed",
type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
default=default_seed_string, # for backward compatibility
help=(
"Set seed for python's random, numpy, torch, and fewshot sampling.\n"
"Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
"respectively, or a single integer to set the same seed for all four.\n"
f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
"(for backward compatibility).\n"
"E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
"Here numpy's seed is not set since the second value is `None`.\n"
"E.g, `--seed 42` sets all four seeds to 42."
),
)
parser.add_argument(
"--trust_remote_code",
action="store_true",
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
)
parser.add_argument(
"--confirm_run_unsafe_code",
action="store_true",
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
)
parser.add_argument(
"--metadata",
type=json.loads,
default=None,
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
)
def execute(self, args: argparse.Namespace) -> None:
"""Execute the evaluation command."""
# Import here to avoid circular imports and for faster CLI loading
from lm_eval.api.eval_config import EvaluationConfig
# Create and validate config (validation now happens in EvaluationConfig)
cfg = EvaluationConfig.from_cli(args)
from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.loggers import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.utils import handle_non_serializable, make_table
# Set up logging
if cfg.wandb_args:
wandb_logger = WandbLogger(cfg.wandb_args, cfg.wandb_config_args)
utils.setup_logging(cfg.verbosity)
eval_logger = logging.getLogger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Set up evaluation tracker
if cfg.output_path:
cfg.hf_hub_log_args["output_path"] = cfg.output_path
if os.environ.get("HF_TOKEN", None):
cfg.hf_hub_log_args["token"] = os.environ.get("HF_TOKEN")
evaluation_tracker = EvaluationTracker(**cfg.hf_hub_log_args)
# Create task manager (metadata already set up in config validation)
task_manager = TaskManager(include_path=cfg.include_path, metadata=cfg.metadata)
# Validation warnings (keep these in CLI as they're logging-specific)
if "push_samples_to_hub" in cfg.hf_hub_log_args and not cfg.log_samples:
eval_logger.warning(
"Pushing samples to the Hub requires --log_samples to be set."
)
if cfg.limit:
eval_logger.warning(
"--limit SHOULD ONLY BE USED FOR TESTING. "
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
# Log task selection (tasks already processed in config)
if cfg.include_path is not None:
eval_logger.info(f"Including path: {cfg.include_path}")
eval_logger.info(f"Selected Tasks: {cfg.tasks}")
# Set up caching
request_caching_args = request_caching_arg_to_dict(
cache_requests=cfg.cache_requests
)
cfg.request_caching_args = request_caching_args
# Run evaluation
results = evaluator.simple_evaluate(
model=cfg.model,
model_args=cfg.model_args,
tasks=cfg.tasks,
num_fewshot=cfg.num_fewshot,
batch_size=cfg.batch_size,
max_batch_size=cfg.max_batch_size,
device=cfg.device,
use_cache=cfg.use_cache,
cache_requests=cfg.request_caching_args.get("cache_requests", False),
rewrite_requests_cache=cfg.request_caching_args.get(
"rewrite_requests_cache", False
),
delete_requests_cache=cfg.request_caching_args.get(
"delete_requests_cache", False
),
limit=cfg.limit,
samples=cfg.samples,
check_integrity=cfg.check_integrity,
write_out=cfg.write_out,
log_samples=cfg.log_samples,
evaluation_tracker=evaluation_tracker,
system_instruction=cfg.system_instruction,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,
gen_kwargs=cfg.gen_kwargs,
task_manager=task_manager,
verbosity=cfg.verbosity,
predict_only=cfg.predict_only,
random_seed=cfg.seed[0] if cfg.seed else None,
numpy_random_seed=cfg.seed[1] if cfg.seed else None,
torch_random_seed=cfg.seed[2] if cfg.seed else None,
fewshot_random_seed=cfg.seed[3] if cfg.seed else None,
confirm_run_unsafe_code=cfg.confirm_run_unsafe_code,
metadata=cfg.metadata,
)
# Process results
if results is not None:
if cfg.log_samples:
samples = results.pop("samples")
dumped = json.dumps(
results, indent=2, default=handle_non_serializable, ensure_ascii=False
)
if cfg.show_config:
print(dumped)
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
# W&B logging
if cfg.wandb_args:
try:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
if cfg.log_samples:
wandb_logger.log_eval_samples(samples)
except Exception as e:
eval_logger.info(f"Logging to W&B failed: {e}")
# Save results
evaluation_tracker.save_results_aggregated(
results=results, samples=samples if cfg.log_samples else None
)
if cfg.log_samples:
for task_name, _ in results["configs"].items():
evaluation_tracker.save_results_samples(
task_name=task_name, samples=samples[task_name]
)
if (
evaluation_tracker.push_results_to_hub
or evaluation_tracker.push_samples_to_hub
):
evaluation_tracker.recreate_metadata_card()
# Print results
print(
f"{cfg.model} ({cfg.model_args}), gen_kwargs: ({cfg.gen_kwargs}), "
f"limit: {cfg.limit}, num_fewshot: {cfg.num_fewshot}, "
f"batch_size: {cfg.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
)
print(make_table(results))
if "groups" in results:
print(make_table(results, "groups"))
if cfg.wandb_args:
wandb_logger.run.finish()
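# ---- lm_eval/_cli/list.py ----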
import argparse
from lm_eval._cli.base import SubCommand
class ListCommand(SubCommand):
"""Command for listing available tasks."""
def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
# Create and configure the parser
super().__init__(*args, **kwargs)
parser = subparsers.add_parser(
"list",
help="List available tasks, groups, subtasks, or tags",
description="List available tasks, groups, subtasks, or tags from the evaluation harness.",
epilog="""
Examples:
lm-eval list tasks # List all available tasks
lm-eval list groups # List task groups only
lm-eval list subtasks # List subtasks only
lm-eval list tags # List available tags
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# Add command-specific arguments
self._add_args(parser)
# Set the function to execute for this subcommand
parser.set_defaults(func=self.execute)
def _add_args(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"what",
choices=["tasks", "groups", "subtasks", "tags"],
help="What to list: tasks (all), groups, subtasks, or tags",
)
parser.add_argument(
"--include_path",
type=str,
default=None,
metavar="DIR",
help="Additional path to include if there are external tasks.",
)
def execute(self, args: argparse.Namespace) -> None:
"""Execute the list command."""
from lm_eval.tasks import TaskManager
task_manager = TaskManager(include_path=args.include_path)
if args.what == "tasks":
print(task_manager.list_all_tasks())
elif args.what == "groups":
print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
elif args.what == "subtasks":
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
elif args.what == "tags":
print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
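# ---- lm_eval/_cli/parser.py ----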
import argparse
import sys
from typing import Dict, Type
from lm_eval._cli.base import SubCommand
from lm_eval._cli.cache import CacheCommand
from lm_eval._cli.evaluate import EvaluateCommand
from lm_eval._cli.list import ListCommand
from lm_eval._cli.validate import ValidateCommand
def check_argument_types(parser: argparse.ArgumentParser):
"""
Check to make sure all CLI args are typed, raises error if not
"""
for action in parser._actions:
# Skip help, subcommands, and const actions
if action.dest in ["help", "command"] or action.const is not None:
continue
        if action.type is None:
            raise ValueError(f"Argument '{action.dest}' doesn't have a type specified.")
class CLIParser:
"""Main CLI parser class that manages all subcommands."""
def __init__(self):
self.parser = None
self.subparsers = None
self.legacy_parser = None
self.command_instances: Dict[str, SubCommand] = {}
def setup_parser(self) -> argparse.ArgumentParser:
"""Set up the main parser with subcommands."""
if self.parser is not None:
return self.parser
self.parser = argparse.ArgumentParser(
prog="lm-eval",
description="Language Model Evaluation Harness",
formatter_class=argparse.RawTextHelpFormatter,
)
# Create subparsers
self.subparsers = self.parser.add_subparsers(
dest="command", help="Available commands", metavar="COMMAND"
)
# Create and register all command instances
self.command_instances = {
"evaluate": EvaluateCommand.create(self.subparsers),
"list": ListCommand.create(self.subparsers),
"validate": ValidateCommand.create(self.subparsers),
"cache": CacheCommand.create(self.subparsers),
}
return self.parser
def setup_legacy_parser(self) -> argparse.ArgumentParser:
"""Set up legacy parser for backward compatibility."""
if self.legacy_parser is not None:
return self.legacy_parser
self.legacy_parser = argparse.ArgumentParser(
formatter_class=argparse.RawTextHelpFormatter
)
# For legacy mode, we just need to add the evaluate command's arguments
# without the subcommand structure. We'll create a temporary instance.
from lm_eval._cli.evaluate import EvaluateCommand as EvalCmd
# Create a minimal instance just to get the arguments
temp_cmd = object.__new__(EvalCmd)
temp_cmd._add_args(self.legacy_parser)
return self.legacy_parser
def parse_args(self, args=None) -> argparse.Namespace:
"""Parse arguments using the main parser."""
parser = self.setup_parser()
check_argument_types(parser)
return parser.parse_args(args)
def parse_legacy_args(self, args=None) -> argparse.Namespace:
"""Parse arguments using the legacy parser."""
parser = self.setup_legacy_parser()
check_argument_types(parser)
return parser.parse_args(args)
def should_use_subcommand_mode(self, argv=None) -> bool:
"""Determine if we should use subcommand mode based on arguments."""
if argv is None:
argv = sys.argv[1:]
# If no arguments, show main help
if len(argv) == 0:
return True
# Check if first argument is a known subcommand
# First ensure parser is set up to populate command_instances
if not self.command_instances:
self.setup_parser()
if len(argv) > 0 and argv[0] in self.command_instances:
return True
return False
def execute(self, argv=None) -> None:
"""Main execution method that handles both subcommand and legacy modes."""
if self.should_use_subcommand_mode(argv):
# Use subcommand mode
if argv is None and len(sys.argv) == 1:
# No arguments provided, show help
self.setup_parser().print_help()
sys.exit(1)
args = self.parse_args(argv)
args.func(args)
else:
# Use legacy mode for backward compatibility
args = self.parse_legacy_args(argv)
self._handle_legacy_mode(args)
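    # Dispatch examples (hedged; the command lines are illustrative):
    #
    #     lm-eval evaluate --model hf --model_args pretrained=gpt2 --tasks hellaswag   # subcommand mode
    #     lm-eval --model hf --model_args pretrained=gpt2 --tasks hellaswag            # legacy mode
    #     lm-eval --tasks list                                                         # legacy task listing (handled below)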
def _handle_legacy_mode(self, args: argparse.Namespace) -> None:
"""Handle legacy CLI mode for backward compatibility."""
# Handle legacy task listing
if hasattr(args, "tasks") and args.tasks in [
"list",
"list_groups",
"list_subtasks",
"list_tags",
]:
from lm_eval.tasks import TaskManager
task_manager = TaskManager(include_path=getattr(args, "include_path", None))
if args.tasks == "list":
print(task_manager.list_all_tasks())
elif args.tasks == "list_groups":
print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
elif args.tasks == "list_subtasks":
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
elif args.tasks == "list_tags":
print(
task_manager.list_all_tasks(list_groups=False, list_subtasks=False)
)
sys.exit(0)
# Handle legacy evaluation
# Use existing instance if available, otherwise create temporary one
if "evaluate" in self.command_instances:
evaluate_cmd = self.command_instances["evaluate"]
else:
# For legacy mode, we don't need the subparser registration
# Just execute with the existing args
from lm_eval._cli.evaluate import EvaluateCommand as EvalCmd
# Create a minimal instance just for execution
evaluate_cmd = object.__new__(EvalCmd)
evaluate_cmd.execute(args)
def add_command(self, name: str, command_class: Type[SubCommand]) -> None:
"""Add a new command to the parser (for extensibility)."""
# If parser is already set up, create and register the command instance
if self.subparsers is not None:
self.command_instances[name] = command_class.create(self.subparsers)
else:
# Store class for later instantiation
if not hasattr(self, "_pending_commands"):
self._pending_commands = {}
self._pending_commands[name] = command_class
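# ---- lm_eval/_cli/validate.py ----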
import argparse
import sys
from lm_eval._cli.base import SubCommand
class ValidateCommand(SubCommand):
"""Command for validating tasks."""
def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
# Create and configure the parser
super().__init__(*args, **kwargs)
parser = subparsers.add_parser(
"validate",
help="Validate task configurations",
description="Validate task configurations and check for errors.",
epilog="""
Examples:
lm-eval validate --tasks hellaswag # Validate single task
lm-eval validate --tasks arc_easy,arc_challenge # Validate multiple tasks
lm-eval validate --tasks mmlu --include_path ./custom_tasks
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
# Add command-specific arguments
self._add_args(parser)
# Set the function to execute for this subcommand
parser.set_defaults(func=self.execute)
def _add_args(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--tasks",
"-t",
required=True,
type=str,
metavar="task1,task2",
help="Comma-separated list of task names to validate",
)
parser.add_argument(
"--include_path",
type=str,
default=None,
metavar="DIR",
help="Additional path to include if there are external tasks.",
)
def execute(self, args: argparse.Namespace) -> None:
"""Execute the validate command."""
from lm_eval.tasks import TaskManager
task_manager = TaskManager(include_path=args.include_path)
task_list = args.tasks.split(",")
print(f"Validating tasks: {task_list}")
# For now, just validate that tasks exist
task_names = task_manager.match_tasks(task_list)
task_missing = [task for task in task_list if task not in task_names]
if task_missing:
missing = ", ".join(task_missing)
print(f"Tasks not found: {missing}")
sys.exit(1)
else:
print("All tasks found and valid")
import json
import logging
from argparse import Namespace
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
from lm_eval.utils import simple_parse_args_string
DICT_KEYS = [
"wandb_args",
"wandb_config_args",
"hf_hub_log_args",
"metadata",
"model_args",
]
@dataclass
class EvaluationConfig:
"""
Simple config container for holding params.
"""
config: Optional[str] = None
model: Optional[str] = None
model_args: Optional[dict] = None
tasks: Optional[str] = None
num_fewshot: Optional[int] = None
batch_size: Optional[int] = None
max_batch_size: Optional[int] = None
device: Optional[str] = None
output_path: Optional[str] = None
limit: Optional[float] = None
samples: Optional[str] = None
use_cache: Optional[str] = None
cache_requests: Optional[str] = None
check_integrity: Optional[bool] = None
write_out: Optional[bool] = None
log_samples: Optional[bool] = None
predict_only: Optional[bool] = None
system_instruction: Optional[str] = None
apply_chat_template: Optional[Union[bool, str]] = None
fewshot_as_multiturn: Optional[bool] = None
show_config: Optional[bool] = None
include_path: Optional[str] = None
gen_kwargs: Optional[dict] = None
verbosity: Optional[str] = None
wandb_args: Optional[dict] = None
wandb_config_args: Optional[dict] = None
hf_hub_log_args: Optional[dict] = None
seed: Optional[list] = None
trust_remote_code: Optional[bool] = None
confirm_run_unsafe_code: Optional[bool] = None
metadata: Optional[dict] = None
request_caching_args: Optional[dict] = None
@staticmethod
def _get_defaults() -> Dict[str, Any]:
"""Get default values for all configuration options."""
return {
"model": "hf",
"model_args": {},
"batch_size": 1,
"check_integrity": False,
"write_out": False,
"log_samples": False,
"predict_only": False,
"fewshot_as_multiturn": False,
"show_config": False,
"trust_remote_code": False,
"confirm_run_unsafe_code": False,
"metadata": {},
"wandb_args": {},
"wandb_config_args": {},
"hf_hub_log_args": {},
"seed": [0, 1234, 1234, 1234],
}
@staticmethod
def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
"""Parse string arguments that should be dictionaries."""
for key in config:
if key in DICT_KEYS and isinstance(config[key], str):
config[key] = simple_parse_args_string(config[key])
return config
@classmethod
def from_cli(cls, namespace: Namespace) -> "EvaluationConfig":
"""
Build an EvaluationConfig by merging with simple precedence:
CLI args > YAML config > built-in defaults
"""
# Start with built-in defaults
config = cls._get_defaults()
# Load and merge YAML config if provided
if hasattr(namespace, "config") and namespace.config:
config.update(cls._load_yaml_config(namespace.config))
# Override with CLI args (only non-None values, exclude non-config args)
excluded_args = {"config", "command", "func"} # argparse internal args
cli_args = {
k: v
for k, v in vars(namespace).items()
if v is not None and k not in excluded_args
}
config.update(cli_args)
# Parse string arguments that should be dictionaries
config = cls._parse_dict_args(config)
# Create instance and validate
instance = cls(**config)
instance.validate_and_preprocess()
return instance
@staticmethod
def _load_yaml_config(config_path: str) -> Dict[str, Any]:
"""Load and validate YAML config file."""
config_file = Path(config_path)
if not config_file.is_file():
raise FileNotFoundError(f"Config file not found: {config_path}")
try:
yaml_data = yaml.safe_load(config_file.read_text())
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {config_path}: {e}")
except (OSError, UnicodeDecodeError) as e:
raise ValueError(f"Could not read config file {config_path}: {e}")
if not isinstance(yaml_data, dict):
raise ValueError(
f"YAML root must be a mapping, got {type(yaml_data).__name__}"
)
return yaml_data
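    # A hedged example of a YAML file passed via --config (file name and values
    # are illustrative; keys mirror the dataclass fields above, and any CLI flag
    # given explicitly still takes precedence):
    #
    #     # eval_config.yaml
    #     tasks: hellaswag,arc_easy
    #     num_fewshot: 5
    #     device: cuda:0
    #     output_path: results/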
def validate_and_preprocess(self) -> None:
"""Validate configuration and preprocess fields after creation."""
self._validate_arguments()
self._process_samples()
self._setup_metadata()
self._apply_trust_remote_code()
self._process_tasks()
def _validate_arguments(self) -> None:
"""Validate configuration arguments and cross-field constraints."""
# predict_only implies log_samples
if self.predict_only:
self.log_samples = True
# log_samples or predict_only requires output_path
if (self.log_samples or self.predict_only) and not self.output_path:
raise ValueError(
"Specify --output_path if providing --log_samples or --predict_only"
)
# fewshot_as_multiturn requires apply_chat_template
if self.fewshot_as_multiturn and self.apply_chat_template is False:
raise ValueError(
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set."
)
# samples and limit are mutually exclusive
if self.samples and self.limit is not None:
raise ValueError("If --samples is not None, then --limit must be None.")
# tasks is required
if self.tasks is None:
raise ValueError("Need to specify task to evaluate.")
def _process_samples(self) -> None:
"""Process samples argument - load from file if needed."""
if self.samples:
if (samples_path := Path(self.samples)).is_file():
self.samples = json.loads(samples_path.read_text())
else:
self.samples = json.loads(self.samples)
    def _process_tasks(self) -> List[str]:
"""Process and validate tasks, return resolved task names."""
from lm_eval import utils
from lm_eval.tasks import TaskManager
# Create task manager with metadata
task_manager = TaskManager(
include_path=self.include_path, metadata=self.metadata
)
# self.tasks is a comma-separated string of task names
task_list = self.tasks.split(",")
task_names = task_manager.match_tasks(task_list)
# Check for any individual task files in the list
for task in [task for task in task_list if task not in task_names]:
task_path = Path(task)
if task_path.is_file():
config = utils.load_yaml_config(str(task_path))
task_names.append(config)
# Check for missing tasks
task_missing = [
task for task in task_list if task not in task_names and "*" not in task
]
if task_missing:
missing = ", ".join(task_missing)
raise ValueError(f"Tasks not found: {missing}")
# Update tasks with resolved names
self.tasks = task_names
return task_names
def _setup_metadata(self) -> None:
"""Set up metadata by merging model_args and metadata."""
if self.model_args is None:
self.model_args = {}
if self.metadata is None:
self.metadata = {}
# Merge model_args and metadata
merged_metadata = self.model_args | self.metadata
self.metadata = merged_metadata
def _apply_trust_remote_code(self) -> None:
"""Apply trust_remote_code setting if enabled."""
if self.trust_remote_code:
eval_logger = logging.getLogger(__name__)
eval_logger.info("Setting HF_DATASETS_TRUST_REMOTE_CODE=true")
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our
# script--`datasets` gets imported by lm_eval internally before these lines can update the env.
import datasets
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
# Add to model_args for the actual model initialization
if self.model_args is None:
self.model_args = {}
self.model_args["trust_remote_code"] = True