Commit fadd26e4 authored by Baber

add tests

parent 649ca8fc
@@ -13,7 +13,7 @@ Equivalently, running the library can be done via the `lm-eval` entrypoint at th
 The CLI now uses a subcommand structure for better organization:
 - `lm-eval run` - Execute evaluations (default behavior)
-- `lm-eval list` - List available tasks, models, etc.
+- `lm-eval ls` - List available tasks, models, etc.
 - `lm-eval validate` - Validate task configurations
 For backward compatibility, if no subcommand is specified, `run` is automatically inserted. So `lm-eval --model hf --tasks hellaswag` is equivalent to `lm-eval run --model hf --tasks hellaswag`.
...
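The auto-insertion described in the README hunk above can be pictured as a small argv preprocessing step. The following is a minimal sketch only; the helper name and the hard-coded subcommand set are illustrative assumptions, not code from this commit.

```python
# Illustrative sketch of the backward-compatibility rule described above:
# if the first CLI token is not a known subcommand, insert "run".
# _insert_default_subcommand and KNOWN_SUBCOMMANDS are hypothetical names,
# not part of HarnessCLI.
KNOWN_SUBCOMMANDS = {"run", "ls", "validate"}


def _insert_default_subcommand(argv: list) -> list:
    """Return argv with 'run' inserted after the program name when needed."""
    args = argv[1:]
    if args and args[0] not in KNOWN_SUBCOMMANDS and args[0] not in ("-h", "--help"):
        return [argv[0], "run", *args]
    return argv


if __name__ == "__main__":
    # 'lm-eval --model hf --tasks hellaswag' behaves like 'lm-eval run --model hf --tasks hellaswag'
    print(_insert_default_subcommand(["lm-eval", "--model", "hf", "--tasks", "hellaswag"]))
```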
-from lm_eval._cli.eval import Eval
+from lm_eval._cli.harness import HarnessCLI
 from lm_eval.utils import setup_logging
 def cli_evaluate() -> None:
-"""Main CLI entry point with subcommand and legacy support."""
+"""Main CLI entry point."""
 setup_logging()
-parser = Eval()
+parser = HarnessCLI()
 args = parser.parse_args()
 parser.execute(args)
...
@@ -2,12 +2,12 @@ import argparse
 import sys
 import textwrap
-from lm_eval._cli.listall import ListAll
+from lm_eval._cli.ls import List
 from lm_eval._cli.run import Run
 from lm_eval._cli.validate import Validate
-class Eval:
+class HarnessCLI:
 """Main CLI parser that manages all subcommands."""
 def __init__(self):
@@ -20,7 +20,7 @@ class Eval:
 lm-eval run --model hf --model_args pretrained=gpt2 --tasks hellaswag
 # List available tasks
-lm-eval list tasks
+lm-eval ls tasks
 # Validate task configurations
 lm-eval validate --tasks hellaswag,arc_easy
@@ -40,7 +40,7 @@ class Eval:
 dest="command", help="Available commands", metavar="COMMAND"
 )
 Run.create(self._subparsers)
-ListAll.create(self._subparsers)
+List.create(self._subparsers)
 Validate.create(self._subparsers)
 def parse_args(self) -> argparse.Namespace:
...
@@ -4,33 +4,33 @@ import textwrap
 from lm_eval._cli.subcommand import SubCommand
-class ListAll(SubCommand):
+class List(SubCommand):
 """Command for listing available tasks."""
 def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
 # Create and configure the parser
 super().__init__(*args, **kwargs)
 self._parser = subparsers.add_parser(
-"list",
+"ls",
 help="List available tasks, groups, subtasks, or tags",
 description="List available tasks, groups, subtasks, or tags from the evaluation harness.",
 usage="lm-eval list [tasks|groups|subtasks|tags] [--include_path DIR]",
 epilog=textwrap.dedent("""
 examples:
 # List all available tasks (includes groups, subtasks, and tags)
-$ lm-eval list tasks
+$ lm-eval ls tasks
 # List only task groups (like 'mmlu', 'glue', 'superglue')
-$ lm-eval list groups
+$ lm-eval ls groups
 # List only individual subtasks (like 'mmlu_abstract_algebra')
-$ lm-eval list subtasks
+$ lm-eval ls subtasks
 # Include external task definitions
-$ lm-eval list tasks --include_path /path/to/external/tasks
+$ lm-eval ls tasks --include_path /path/to/external/tasks
 # List tasks from multiple external paths
-$ lm-eval list tasks --include_path "/path/to/tasks1:/path/to/tasks2"
+$ lm-eval ls tasks --include_path "/path/to/tasks1:/path/to/tasks2"
 organization:
 • Groups: Collections of tasks with aggregated metric across subtasks (e.g., 'mmlu')
@@ -46,7 +46,7 @@ class ListAll(SubCommand):
 formatter_class=argparse.RawDescriptionHelpFormatter,
 )
 self._add_args()
-self._parser.set_defaults(func=lambda arg: self._parser.print_help())
+self._parser.set_defaults(func=self._execute)
 def _add_args(self) -> None:
 self._parser.add_argument(
@@ -63,7 +63,7 @@ class ListAll(SubCommand):
 help="Additional path to include if there are external tasks.",
 )
-def execute(self, args: argparse.Namespace) -> None:
+def _execute(self, args: argparse.Namespace) -> None:
 """Execute the list command."""
 from lm_eval.tasks import TaskManager
...
@@ -42,12 +42,12 @@ class Run(SubCommand):
 formatter_class=argparse.RawDescriptionHelpFormatter,
 )
 self._add_args()
-self._parser.set_defaults(func=self.execute)
+self._parser.set_defaults(func=self._execute)
 def _add_args(self) -> None:
 self._parser = self._parser
-# Configuration
+# Defaults are set in config/evaluate_config.py
 config_group = self._parser.add_argument_group("configuration")
 config_group.add_argument(
 "--config",
@@ -64,7 +64,7 @@ class Run(SubCommand):
 "--model",
 "-m",
 type=str,
-default="hf",
+default=None,
 metavar="MODEL_NAME",
 help="Model name (default: hf)",
 )
@@ -283,7 +283,7 @@ class Run(SubCommand):
 advanced_group.add_argument(
 "--seed",
 type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
-default=default_seed_string,
+default=None,
 metavar="SEED|S1,S2,S3,S4",
 help=textwrap.dedent(f"""
 Random seeds for python,numpy,torch,fewshot (default: {default_seed_string}).
@@ -309,18 +309,21 @@ class Run(SubCommand):
 default=None,
 metavar="JSON",
 help=textwrap.dedent(
-"JSON metadata for task configs (merged with model_args), required for some tasks such as RULER"
+"""JSON metadata for task configs (merged with model_args), required for some tasks such as RULER"""
 ),
 )
-def execute(self, args: argparse.Namespace) -> None:
+def _execute(self, args: argparse.Namespace) -> None:
 """Runs the evaluation harness with the provided arguments."""
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 from lm_eval.config.evaluate_config import EvaluatorConfig
-# Create and validate config (most validation now happens in EvaluationConfig)
+eval_logger = logging.getLogger(__name__)
+# Create and validate config (most validation now occurs in EvaluationConfig)
 cfg = EvaluatorConfig.from_cli(args)
-from lm_eval import simple_evaluate, utils
+from lm_eval import simple_evaluate
 from lm_eval.loggers import EvaluationTracker, WandbLogger
 from lm_eval.utils import handle_non_serializable, make_table
@@ -328,10 +331,6 @@ class Run(SubCommand):
 if cfg.wandb_args:
 wandb_logger = WandbLogger(cfg.wandb_args, cfg.wandb_config_args)
-utils.setup_logging(cfg.verbosity)
-eval_logger = logging.getLogger(__name__)
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
 # Set up evaluation tracker
 if cfg.output_path:
 cfg.hf_hub_log_args["output_path"] = cfg.output_path
@@ -342,7 +341,7 @@ class Run(SubCommand):
 evaluation_tracker = EvaluationTracker(**cfg.hf_hub_log_args)
 # Create task manager (metadata already set up in config validation)
-task_manager = cfg.process_tasks()
+task_manager = cfg.process_tasks(cfg.metadata)
 # Validation warnings (keep these in CLI as they're logging-specific)
 if "push_samples_to_hub" in cfg.hf_hub_log_args and not cfg.log_samples:
...
@@ -17,8 +17,3 @@ class SubCommand(ABC):
 def _add_args(self) -> None:
 """Add arguments specific to this subcommand."""
 pass
-@abstractmethod
-def execute(self, args: argparse.Namespace) -> None:
-"""Execute the subcommand with the given arguments."""
-pass
@@ -73,7 +73,7 @@ class Validate(SubCommand):
 formatter_class=argparse.RawDescriptionHelpFormatter,
 )
 self._add_args()
-self._parser.set_defaults(func=lambda arg: self._parser.print_help())
+self._parser.set_defaults(func=self._execute)
 def _add_args(self) -> None:
 self._parser.add_argument(
@@ -92,7 +92,7 @@ class Validate(SubCommand):
 help="Additional path to include if there are external tasks.",
 )
-def execute(self, args: argparse.Namespace) -> None:
+def _execute(self, args: argparse.Namespace) -> None:
 """Execute the validate command."""
 from lm_eval.tasks import TaskManager
...
 import json
 import logging
+import textwrap
 from argparse import Namespace
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
@@ -186,14 +187,6 @@ class EvaluatorConfig:
 metadata={"help": "Additional metadata for tasks that require it"},
 )
-@staticmethod
-def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
-"""Parse string arguments that should be dictionaries."""
-for key in config:
-if key in DICT_KEYS and isinstance(config[key], str):
-config[key] = simple_parse_args_string(config[key])
-return config
 @classmethod
 def from_cli(cls, namespace: Namespace) -> "EvaluatorConfig":
 """
@@ -204,8 +197,8 @@ class EvaluatorConfig:
 config = asdict(cls())
 # Load and merge YAML config if provided
-if hasattr(namespace, "config") and namespace.config:
-config.update(cls._load_yaml_config(namespace.config))
+if used_config := hasattr(namespace, "config") and namespace.config:
+config.update(cls.load_yaml_config(namespace.config))
 # Override with CLI args (only truthy values, exclude non-config args)
 excluded_args = {"config", "command", "func"} # argparse internal args
@@ -219,7 +212,9 @@ class EvaluatorConfig:
 # Create instance and validate
 instance = cls(**config)
-instance.validate_and_preprocess()
+if used_config:
+print(textwrap.dedent(f"""{instance}"""))
+instance.configure()
 return instance
@@ -230,19 +225,24 @@ class EvaluatorConfig:
 Merges with built-in defaults and validates.
 """
 # Load YAML config
-yaml_config = cls._load_yaml_config(config_path)
+yaml_config = cls.load_yaml_config(config_path)
 # Parse string arguments that should be dictionaries
 yaml_config = cls._parse_dict_args(yaml_config)
-# Create instance and validate
 instance = cls(**yaml_config)
-instance.validate_and_preprocess()
+instance.configure()
 return instance
 @staticmethod
-def _load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
+def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
+"""Parse string arguments that should be dictionaries."""
+for key in config:
+if key in DICT_KEYS and isinstance(config[key], str):
+config[key] = simple_parse_args_string(config[key])
+return config
+@staticmethod
+def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
 """Load and validate YAML config file."""
 config_file = (
 Path(config_path) if not isinstance(config_path, Path) else config_path
@@ -252,6 +252,7 @@ class EvaluatorConfig:
 try:
 yaml_data = yaml.safe_load(config_file.read_text())
+print(textwrap.dedent(f"""yaml: {yaml_data}"""))
 except yaml.YAMLError as e:
 raise ValueError(f"Invalid YAML in {config_path}: {e}")
 except (OSError, UnicodeDecodeError) as e:
@@ -264,11 +265,11 @@ class EvaluatorConfig:
 return yaml_data
-def validate_and_preprocess(self) -> None:
+def configure(self) -> None:
 """Validate configuration and preprocess fields after creation."""
 self._validate_arguments()
 self._process_arguments()
-self._apply_trust_remote_code()
+self._set_trust_remote_code()
 def _validate_arguments(self) -> None:
 """Validate configuration arguments and cross-field constraints."""
@@ -365,7 +366,7 @@ class EvaluatorConfig:
 self.tasks = task_names
 return task_manager
-def _apply_trust_remote_code(self) -> None:
+def _set_trust_remote_code(self) -> None:
 """Apply trust_remote_code setting if enabled."""
 if self.trust_remote_code:
 # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
...
# Language Model Evaluation Harness Configuration File
#
# This YAML configuration file allows you to specify evaluation parameters
# instead of passing them as command-line arguments.
#
# Usage:
# $ lm_eval --config configs/default_config.yaml
#
# You can override any values in this config with command-line arguments:
# $ lm_eval --config configs/default_config.yaml --model_args pretrained=gpt2 --tasks mmlu
#
# All parameters are optional and have the same meaning as their CLI counterparts.
model: hf
model_args:
pretrained: EleutherAI/pythia-14m
dtype: float16
tasks:
- hellaswag
- gsm8k
batch_size: 1
trust_remote_code: true
log_samples: true
output_path: ./test
limit: 10
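To make the precedence spelled out in the comments above concrete (built-in defaults, then YAML values, then command-line overrides), here is a standalone sketch. It mirrors the merge order of `EvaluatorConfig.from_cli` in this commit, but `merge_config` itself is a hypothetical helper, not part of the harness.

```python
# Hypothetical sketch of the defaults <- YAML <- CLI precedence described above;
# EvaluatorConfig.from_cli in this commit implements the real behaviour.
from typing import Any, Dict, Optional

import yaml


def merge_config(yaml_path: Optional[str], cli_args: Dict[str, Any]) -> Dict[str, Any]:
    """Return a flat config dict: dataclass defaults <- YAML file <- truthy CLI args."""
    config: Dict[str, Any] = {}  # stand-in for asdict(EvaluatorConfig())
    if yaml_path:
        with open(yaml_path) as f:
            config.update(yaml.safe_load(f) or {})
    # Only CLI values that were actually set override the YAML/defaults.
    excluded = {"config", "command", "func"}  # argparse-internal keys
    config.update({k: v for k, v in cli_args.items() if v and k not in excluded})
    return config


if __name__ == "__main__":
    # With the YAML above, passing `--tasks mmlu` on the command line would win
    # over the file's `tasks: [hellaswag, gsm8k]`.
    print(merge_config(None, {"model": "hf", "tasks": "mmlu", "limit": None}))
```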
import argparse
import sys
from unittest.mock import MagicMock, patch
import pytest
from lm_eval._cli.harness import HarnessCLI
from lm_eval._cli.ls import List
from lm_eval._cli.run import Run
from lm_eval._cli.utils import (
_int_or_none_list_arg_type,
check_argument_types,
request_caching_arg_to_dict,
try_parse_json,
)
from lm_eval._cli.validate import Validate
class TestHarnessCLI:
"""Test the main HarnessCLI class."""
def test_harness_cli_init(self):
"""Test HarnessCLI initialization."""
cli = HarnessCLI()
assert cli._parser is not None
assert cli._subparsers is not None
def test_harness_cli_has_subcommands(self):
"""Test that HarnessCLI has all expected subcommands."""
cli = HarnessCLI()
subcommands = cli._subparsers.choices
assert "run" in subcommands
assert "ls" in subcommands
assert "validate" in subcommands
def test_harness_cli_backward_compatibility(self):
"""Test backward compatibility: inserting 'run' when no subcommand is provided."""
cli = HarnessCLI()
test_args = ["lm-eval", "--model", "hf", "--tasks", "hellaswag"]
with patch.object(sys, "argv", test_args):
args = cli.parse_args()
assert args.command == "run"
assert args.model == "hf"
assert args.tasks == "hellaswag"
def test_harness_cli_help_default(self):
"""Test that help is printed when no arguments are provided."""
cli = HarnessCLI()
with patch.object(sys, "argv", ["lm-eval"]):
args = cli.parse_args()
# The func is a lambda that calls print_help
# Let's test it calls the help function correctly
with patch.object(cli._parser, "print_help") as mock_help:
args.func(args)
mock_help.assert_called_once()
def test_harness_cli_run_help_only(self):
"""Test that 'lm-eval run' shows help."""
cli = HarnessCLI()
with patch.object(sys, "argv", ["lm-eval", "run"]):
with pytest.raises(SystemExit):
cli.parse_args()
class TestListCommand:
"""Test the List subcommand."""
def test_list_command_creation(self):
"""Test List command creation."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
list_cmd = List.create(subparsers)
assert isinstance(list_cmd, List)
def test_list_command_arguments(self):
"""Test List command arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
List.create(subparsers)
# Test valid arguments
args = parser.parse_args(["ls", "tasks"])
assert args.what == "tasks"
assert args.include_path is None
args = parser.parse_args(["ls", "groups", "--include_path", "/path/to/tasks"])
assert args.what == "groups"
assert args.include_path == "/path/to/tasks"
def test_list_command_choices(self):
"""Test List command only accepts valid choices."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
List.create(subparsers)
# Valid choices should work
for choice in ["tasks", "groups", "subtasks", "tags"]:
args = parser.parse_args(["ls", choice])
assert args.what == choice
# Invalid choice should fail
with pytest.raises(SystemExit):
parser.parse_args(["ls", "invalid"])
@patch("lm_eval.tasks.TaskManager")
def test_list_command_execute_tasks(self, mock_task_manager):
"""Test List command execution for tasks."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
list_cmd = List.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.list_all_tasks.return_value = "task1\ntask2\ntask3"
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["ls", "tasks"])
with patch("builtins.print") as mock_print:
list_cmd._execute(args)
mock_print.assert_called_once_with("task1\ntask2\ntask3")
mock_tm_instance.list_all_tasks.assert_called_once_with()
@patch("lm_eval.tasks.TaskManager")
def test_list_command_execute_groups(self, mock_task_manager):
"""Test List command execution for groups."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
list_cmd = List.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.list_all_tasks.return_value = "group1\ngroup2"
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["ls", "groups"])
with patch("builtins.print") as mock_print:
list_cmd._execute(args)
mock_print.assert_called_once_with("group1\ngroup2")
mock_tm_instance.list_all_tasks.assert_called_once_with(
list_subtasks=False, list_tags=False
)
class TestRunCommand:
"""Test the Run subcommand."""
def test_run_command_creation(self):
"""Test Run command creation."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
run_cmd = Run.create(subparsers)
assert isinstance(run_cmd, Run)
def test_run_command_basic_arguments(self):
"""Test Run command basic arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
args = parser.parse_args(
["run", "--model", "hf", "--tasks", "hellaswag,arc_easy"]
)
assert args.model == "hf"
assert args.tasks == "hellaswag,arc_easy"
def test_run_command_model_args(self):
"""Test Run command model arguments parsing."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
# Test key=value format
args = parser.parse_args(["run", "--model_args", "pretrained=gpt2,device=cuda"])
assert args.model_args == "pretrained=gpt2,device=cuda"
# Test JSON format
args = parser.parse_args(
["run", "--model_args", '{"pretrained": "gpt2", "device": "cuda"}']
)
assert args.model_args == {"pretrained": "gpt2", "device": "cuda"}
def test_run_command_batch_size(self):
"""Test Run command batch size arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
# Test integer batch size
args = parser.parse_args(["run", "--batch_size", "32"])
assert args.batch_size == "32"
# Test auto batch size
args = parser.parse_args(["run", "--batch_size", "auto"])
assert args.batch_size == "auto"
# Test auto with repetitions
args = parser.parse_args(["run", "--batch_size", "auto:5"])
assert args.batch_size == "auto:5"
def test_run_command_seed_parsing(self):
"""Test Run command seed parsing."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
# Test single seed
args = parser.parse_args(["run", "--seed", "42"])
assert args.seed == [42, 42, 42, 42]
# Test multiple seeds
args = parser.parse_args(["run", "--seed", "0,1234,5678,9999"])
assert args.seed == [0, 1234, 5678, 9999]
# Test with None values
args = parser.parse_args(["run", "--seed", "0,None,1234,None"])
assert args.seed == [0, None, 1234, None]
@patch("lm_eval.simple_evaluate")
@patch("lm_eval.config.evaluate_config.EvaluatorConfig")
@patch("lm_eval.loggers.EvaluationTracker")
@patch("lm_eval.utils.make_table")
def test_run_command_execute_basic(
self, mock_make_table, mock_tracker, mock_config, mock_simple_evaluate
):
"""Test Run command basic execution."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
run_cmd = Run.create(subparsers)
# Mock configuration
mock_cfg_instance = MagicMock()
mock_cfg_instance.wandb_args = None
mock_cfg_instance.output_path = None
mock_cfg_instance.hf_hub_log_args = {}
mock_cfg_instance.include_path = None
mock_cfg_instance.tasks = ["hellaswag"]
mock_cfg_instance.model = "hf"
mock_cfg_instance.model_args = {"pretrained": "gpt2"}
mock_cfg_instance.gen_kwargs = {}
mock_cfg_instance.limit = None
mock_cfg_instance.num_fewshot = 0
mock_cfg_instance.batch_size = 1
mock_cfg_instance.log_samples = False
mock_cfg_instance.process_tasks.return_value = MagicMock()
mock_config.from_cli.return_value = mock_cfg_instance
# Mock evaluation results
mock_simple_evaluate.return_value = {
"results": {"hellaswag": {"acc": 0.75}},
"config": {"batch_sizes": [1]},
"configs": {"hellaswag": {}},
"versions": {"hellaswag": "1.0"},
"n-shot": {"hellaswag": 0},
}
# Mock make_table to avoid complex table rendering
mock_make_table.return_value = (
"| Task | Result |\n|------|--------|\n| hellaswag | 0.75 |"
)
args = parser.parse_args(["run", "--model", "hf", "--tasks", "hellaswag"])
with patch("builtins.print"):
run_cmd._execute(args)
mock_config.from_cli.assert_called_once()
mock_simple_evaluate.assert_called_once()
mock_make_table.assert_called_once()
class TestValidateCommand:
"""Test the Validate subcommand."""
def test_validate_command_creation(self):
"""Test Validate command creation."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
validate_cmd = Validate.create(subparsers)
assert isinstance(validate_cmd, Validate)
def test_validate_command_arguments(self):
"""Test Validate command arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Validate.create(subparsers)
args = parser.parse_args(["validate", "--tasks", "hellaswag,arc_easy"])
assert args.tasks == "hellaswag,arc_easy"
assert args.include_path is None
args = parser.parse_args(
["validate", "--tasks", "custom_task", "--include_path", "/path/to/tasks"]
)
assert args.tasks == "custom_task"
assert args.include_path == "/path/to/tasks"
def test_validate_command_requires_tasks(self):
"""Test Validate command requires tasks argument."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Validate.create(subparsers)
with pytest.raises(SystemExit):
parser.parse_args(["validate"])
@patch("lm_eval.tasks.TaskManager")
def test_validate_command_execute_success(self, mock_task_manager):
"""Test Validate command execution with valid tasks."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
validate_cmd = Validate.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.match_tasks.return_value = ["hellaswag", "arc_easy"]
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["validate", "--tasks", "hellaswag,arc_easy"])
with patch("builtins.print") as mock_print:
validate_cmd._execute(args)
mock_print.assert_any_call("Validating tasks: ['hellaswag', 'arc_easy']")
mock_print.assert_any_call("All tasks found and valid")
@patch("lm_eval.tasks.TaskManager")
def test_validate_command_execute_missing_tasks(self, mock_task_manager):
"""Test Validate command execution with missing tasks."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
validate_cmd = Validate.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.match_tasks.return_value = ["hellaswag"]
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["validate", "--tasks", "hellaswag,nonexistent"])
with patch("builtins.print") as mock_print:
with pytest.raises(SystemExit) as exc_info:
validate_cmd._execute(args)
assert exc_info.value.code == 1
mock_print.assert_any_call("Tasks not found: nonexistent")
class TestCLIUtils:
"""Test CLI utility functions."""
def test_try_parse_json_with_json_string(self):
"""Test try_parse_json with valid JSON string."""
result = try_parse_json('{"key": "value", "num": 42}')
assert result == {"key": "value", "num": 42}
def test_try_parse_json_with_dict(self):
"""Test try_parse_json with dict input."""
input_dict = {"key": "value"}
result = try_parse_json(input_dict)
assert result is input_dict
def test_try_parse_json_with_none(self):
"""Test try_parse_json with None."""
result = try_parse_json(None)
assert result is None
def test_try_parse_json_with_plain_string(self):
"""Test try_parse_json with plain string."""
result = try_parse_json("key=value,key2=value2")
assert result == "key=value,key2=value2"
def test_try_parse_json_with_invalid_json(self):
"""Test try_parse_json with invalid JSON."""
with pytest.raises(ValueError) as exc_info:
try_parse_json('{key: "value"}') # Invalid JSON (unquoted key)
assert "Invalid JSON" in str(exc_info.value)
assert "double quotes" in str(exc_info.value)
def test_int_or_none_list_single_value(self):
"""Test _int_or_none_list_arg_type with single value."""
result = _int_or_none_list_arg_type(3, 4, "0,1,2,3", "42")
assert result == [42, 42, 42, 42]
def test_int_or_none_list_multiple_values(self):
"""Test _int_or_none_list_arg_type with multiple values."""
result = _int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,20,30,40")
assert result == [10, 20, 30, 40]
def test_int_or_none_list_with_none(self):
"""Test _int_or_none_list_arg_type with None values."""
result = _int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,None,30,None")
assert result == [10, None, 30, None]
def test_int_or_none_list_invalid_value(self):
"""Test _int_or_none_list_arg_type with invalid value."""
with pytest.raises(ValueError):
_int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,invalid,30,40")
def test_int_or_none_list_too_few_values(self):
"""Test _int_or_none_list_arg_type with too few values."""
with pytest.raises(ValueError):
_int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,20")
def test_int_or_none_list_too_many_values(self):
"""Test _int_or_none_list_arg_type with too many values."""
with pytest.raises(ValueError):
_int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,20,30,40,50")
def test_request_caching_arg_to_dict_none(self):
"""Test request_caching_arg_to_dict with None."""
result = request_caching_arg_to_dict(None)
assert result == {}
def test_request_caching_arg_to_dict_true(self):
"""Test request_caching_arg_to_dict with 'true'."""
result = request_caching_arg_to_dict("true")
assert result == {
"cache_requests": True,
"rewrite_requests_cache": False,
"delete_requests_cache": False,
}
def test_request_caching_arg_to_dict_refresh(self):
"""Test request_caching_arg_to_dict with 'refresh'."""
result = request_caching_arg_to_dict("refresh")
assert result == {
"cache_requests": True,
"rewrite_requests_cache": True,
"delete_requests_cache": False,
}
def test_request_caching_arg_to_dict_delete(self):
"""Test request_caching_arg_to_dict with 'delete'."""
result = request_caching_arg_to_dict("delete")
assert result == {
"cache_requests": False,
"rewrite_requests_cache": False,
"delete_requests_cache": True,
}
def test_check_argument_types_raises_on_untyped(self):
"""Test check_argument_types raises error for untyped arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--untyped") # No type specified
with pytest.raises(ValueError) as exc_info:
check_argument_types(parser)
assert "untyped" in str(exc_info.value)
assert "doesn't have a type specified" in str(exc_info.value)
def test_check_argument_types_passes_on_typed(self):
"""Test check_argument_types passes for typed arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--typed", type=str)
# Should not raise
check_argument_types(parser)
def test_check_argument_types_skips_const_actions(self):
"""Test check_argument_types skips const actions."""
parser = argparse.ArgumentParser()
parser.add_argument("--flag", action="store_const", const=True)
# Should not raise
check_argument_types(parser)