Commit 15ce554c authored by Baber

add tests

parent b9ee592b
lm_eval/_cli/utils.py
@@ -4,7 +4,7 @@ import logging
 from typing import Optional, Union
 
-def try_parse_json(value: Union[dict, str]) -> Union[str, dict, None]:
+def try_parse_json(value: Union[str, dict, None]) -> Union[str, dict, None]:
     """Try to parse a string as JSON. If it fails, return the original string."""
     if value is None:
         return None
@@ -69,7 +69,7 @@ def request_caching_arg_to_dict(cache_requests: Optional[str]) -> dict[str, bool]
     return request_caching_args
 
-def check_argument_types(parser: argparse.ArgumentParser):
+def check_argument_types(parser: argparse.ArgumentParser) -> None:
     """
     Check to make sure all CLI args are typed, raises error if not
     """
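
For orientation, here is a minimal sketch of a try_parse_json that is consistent with the widened signature above and with the TestCLIUtils assertions in the new test file below (hypothetical; the actual implementation in lm_eval._cli.utils may differ in details):

import json
from typing import Union

def try_parse_json(value: Union[str, dict, None]) -> Union[str, dict, None]:
    """Try to parse a string as JSON; fall back to the original string."""
    if value is None or isinstance(value, dict):
        return value  # dicts and None pass through unchanged
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        if value.strip().startswith("{"):
            # Looks like JSON but is malformed, e.g. '{key: "value"}'
            raise ValueError(
                f"Invalid JSON: {value}. Hint: JSON keys and strings need double quotes."
            )
        return value  # plain strings like "key=value,key2=value2" pass through

The new test file added by this commit follows.
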
import argparse
import sys
from unittest.mock import MagicMock, patch
import pytest
from lm_eval._cli.harness import HarnessCLI
from lm_eval._cli.ls import List
from lm_eval._cli.run import Run
from lm_eval._cli.utils import (
_int_or_none_list_arg_type,
check_argument_types,
request_caching_arg_to_dict,
try_parse_json,
)
from lm_eval._cli.validate import Validate


class TestHarnessCLI:
"""Test the main HarnessCLI class."""
def test_harness_cli_init(self):
"""Test HarnessCLI initialization."""
cli = HarnessCLI()
assert cli._parser is not None
assert cli._subparsers is not None
def test_harness_cli_has_subcommands(self):
"""Test that HarnessCLI has all expected subcommands."""
cli = HarnessCLI()
subcommands = cli._subparsers.choices
assert "run" in subcommands
assert "ls" in subcommands
assert "validate" in subcommands
def test_harness_cli_backward_compatibility(self):
"""Test backward compatibility: inserting 'run' when no subcommand is provided."""
cli = HarnessCLI()
test_args = ["lm-eval", "--model", "hf", "--tasks", "hellaswag"]
with patch.object(sys, "argv", test_args):
args = cli.parse_args()
assert args.command == "run"
assert args.model == "hf"
assert args.tasks == "hellaswag"
def test_harness_cli_help_default(self):
"""Test that help is printed when no arguments are provided."""
cli = HarnessCLI()
with patch.object(sys, "argv", ["lm-eval"]):
args = cli.parse_args()
        # The default func is a lambda that calls the parser's print_help;
        # verify that invoking it prints help exactly once.
with patch.object(cli._parser, "print_help") as mock_help:
args.func(args)
mock_help.assert_called_once()
def test_harness_cli_run_help_only(self):
"""Test that 'lm-eval run' shows help."""
cli = HarnessCLI()
with patch.object(sys, "argv", ["lm-eval", "run"]):
with pytest.raises(SystemExit):
cli.parse_args()
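

# The backward-compatibility test above implies (hypothetically) that
# HarnessCLI.parse_args inspects sys.argv and inserts "run" when the first
# argument is not a known subcommand, so that
#   lm-eval --model hf --tasks hellaswag
# parses the same as
#   lm-eval run --model hf --tasks hellaswag
# The exact mechanism lives in lm_eval._cli.harness.
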
class TestListCommand:
"""Test the List subcommand."""
def test_list_command_creation(self):
"""Test List command creation."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
list_cmd = List.create(subparsers)
assert isinstance(list_cmd, List)
def test_list_command_arguments(self):
"""Test List command arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
List.create(subparsers)
# Test valid arguments
args = parser.parse_args(["ls", "tasks"])
assert args.what == "tasks"
assert args.include_path is None
args = parser.parse_args(["ls", "groups", "--include_path", "/path/to/tasks"])
assert args.what == "groups"
assert args.include_path == "/path/to/tasks"
def test_list_command_choices(self):
"""Test List command only accepts valid choices."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
List.create(subparsers)
# Valid choices should work
for choice in ["tasks", "groups", "subtasks", "tags"]:
args = parser.parse_args(["ls", choice])
assert args.what == choice
# Invalid choice should fail
with pytest.raises(SystemExit):
parser.parse_args(["ls", "invalid"])
@patch("lm_eval.tasks.TaskManager")
def test_list_command_execute_tasks(self, mock_task_manager):
"""Test List command execution for tasks."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
list_cmd = List.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.list_all_tasks.return_value = "task1\ntask2\ntask3"
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["ls", "tasks"])
with patch("builtins.print") as mock_print:
list_cmd._execute(args)
mock_print.assert_called_once_with("task1\ntask2\ntask3")
mock_tm_instance.list_all_tasks.assert_called_once_with()
@patch("lm_eval.tasks.TaskManager")
def test_list_command_execute_groups(self, mock_task_manager):
"""Test List command execution for groups."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
list_cmd = List.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.list_all_tasks.return_value = "group1\ngroup2"
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["ls", "groups"])
with patch("builtins.print") as mock_print:
list_cmd._execute(args)
mock_print.assert_called_once_with("group1\ngroup2")
mock_tm_instance.list_all_tasks.assert_called_once_with(
list_subtasks=False, list_tags=False
)
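

# Taken together, the two execution tests above pin down the assumed dispatch
# in List._execute: "tasks" maps to TaskManager.list_all_tasks() with no
# arguments, "groups" to list_all_tasks(list_subtasks=False, list_tags=False),
# and the returned string is printed verbatim.
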
class TestRunCommand:
"""Test the Run subcommand."""
def test_run_command_creation(self):
"""Test Run command creation."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
run_cmd = Run.create(subparsers)
assert isinstance(run_cmd, Run)
def test_run_command_basic_arguments(self):
"""Test Run command basic arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
args = parser.parse_args(
["run", "--model", "hf", "--tasks", "hellaswag,arc_easy"]
)
assert args.model == "hf"
assert args.tasks == "hellaswag,arc_easy"
def test_run_command_model_args(self):
"""Test Run command model arguments parsing."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
# Test key=value format
args = parser.parse_args(["run", "--model_args", "pretrained=gpt2,device=cuda"])
assert args.model_args == "pretrained=gpt2,device=cuda"
# Test JSON format
args = parser.parse_args(
["run", "--model_args", '{"pretrained": "gpt2", "device": "cuda"}']
)
assert args.model_args == {"pretrained": "gpt2", "device": "cuda"}
def test_run_command_batch_size(self):
"""Test Run command batch size arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
# Test integer batch size
args = parser.parse_args(["run", "--batch_size", "32"])
assert args.batch_size == "32"
# Test auto batch size
args = parser.parse_args(["run", "--batch_size", "auto"])
assert args.batch_size == "auto"
# Test auto with repetitions
args = parser.parse_args(["run", "--batch_size", "auto:5"])
assert args.batch_size == "auto:5"
def test_run_command_seed_parsing(self):
"""Test Run command seed parsing."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Run.create(subparsers)
# Test single seed
args = parser.parse_args(["run", "--seed", "42"])
assert args.seed == [42, 42, 42, 42]
# Test multiple seeds
args = parser.parse_args(["run", "--seed", "0,1234,5678,9999"])
assert args.seed == [0, 1234, 5678, 9999]
# Test with None values
args = parser.parse_args(["run", "--seed", "0,None,1234,None"])
assert args.seed == [0, None, 1234, None]
@patch("lm_eval.simple_evaluate")
@patch("lm_eval.config.evaluate_config.EvaluatorConfig")
@patch("lm_eval.loggers.EvaluationTracker")
@patch("lm_eval.utils.make_table")
def test_run_command_execute_basic(
self, mock_make_table, mock_tracker, mock_config, mock_simple_evaluate
):
"""Test Run command basic execution."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
run_cmd = Run.create(subparsers)
# Mock configuration
mock_cfg_instance = MagicMock()
mock_cfg_instance.wandb_args = None
mock_cfg_instance.output_path = None
mock_cfg_instance.hf_hub_log_args = {}
mock_cfg_instance.include_path = None
mock_cfg_instance.tasks = ["hellaswag"]
mock_cfg_instance.model = "hf"
mock_cfg_instance.model_args = {"pretrained": "gpt2"}
mock_cfg_instance.gen_kwargs = {}
mock_cfg_instance.limit = None
mock_cfg_instance.num_fewshot = 0
mock_cfg_instance.batch_size = 1
mock_cfg_instance.log_samples = False
mock_cfg_instance.process_tasks.return_value = MagicMock()
mock_config.from_cli.return_value = mock_cfg_instance
# Mock evaluation results
mock_simple_evaluate.return_value = {
"results": {"hellaswag": {"acc": 0.75}},
"config": {"batch_sizes": [1]},
"configs": {"hellaswag": {}},
"versions": {"hellaswag": "1.0"},
"n-shot": {"hellaswag": 0},
}
# Mock make_table to avoid complex table rendering
mock_make_table.return_value = (
"| Task | Result |\n|------|--------|\n| hellaswag | 0.75 |"
)
args = parser.parse_args(["run", "--model", "hf", "--tasks", "hellaswag"])
with patch("builtins.print"):
run_cmd._execute(args)
mock_config.from_cli.assert_called_once()
mock_simple_evaluate.assert_called_once()
mock_make_table.assert_called_once()
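

# The --seed expansion in test_run_command_seed_parsing above, and the direct
# TestCLIUtils cases further down, are consistent with the following sketch.
# It is hypothetical and renamed with a _sketch_ prefix so it cannot shadow
# the real helper imported from lm_eval._cli.utils.
def _sketch_int_or_none_list_arg_type(min_len, max_len, defaults, value, split_char=","):
    # `defaults` mirrors the real signature but is unused in this sketch.
    def parse(item):
        item = item.strip().lower()
        return None if item == "none" else int(item)  # int() raises ValueError

    items = [parse(v) for v in value.split(split_char)]
    if len(items) == 1:
        items = items * max_len  # broadcast "42" -> [42, 42, 42, 42]
    elif not (min_len <= len(items) <= max_len):
        raise ValueError(f"Expected between {min_len} and {max_len} values, got {len(items)}")
    return items
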
class TestValidateCommand:
"""Test the Validate subcommand."""
def test_validate_command_creation(self):
"""Test Validate command creation."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
validate_cmd = Validate.create(subparsers)
assert isinstance(validate_cmd, Validate)
def test_validate_command_arguments(self):
"""Test Validate command arguments."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Validate.create(subparsers)
args = parser.parse_args(["validate", "--tasks", "hellaswag,arc_easy"])
assert args.tasks == "hellaswag,arc_easy"
assert args.include_path is None
args = parser.parse_args(
["validate", "--tasks", "custom_task", "--include_path", "/path/to/tasks"]
)
assert args.tasks == "custom_task"
assert args.include_path == "/path/to/tasks"
def test_validate_command_requires_tasks(self):
"""Test Validate command requires tasks argument."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
Validate.create(subparsers)
with pytest.raises(SystemExit):
parser.parse_args(["validate"])
@patch("lm_eval.tasks.TaskManager")
def test_validate_command_execute_success(self, mock_task_manager):
"""Test Validate command execution with valid tasks."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
validate_cmd = Validate.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.match_tasks.return_value = ["hellaswag", "arc_easy"]
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["validate", "--tasks", "hellaswag,arc_easy"])
with patch("builtins.print") as mock_print:
validate_cmd._execute(args)
mock_print.assert_any_call("Validating tasks: ['hellaswag', 'arc_easy']")
mock_print.assert_any_call("All tasks found and valid")
@patch("lm_eval.tasks.TaskManager")
def test_validate_command_execute_missing_tasks(self, mock_task_manager):
"""Test Validate command execution with missing tasks."""
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
validate_cmd = Validate.create(subparsers)
mock_tm_instance = MagicMock()
mock_tm_instance.match_tasks.return_value = ["hellaswag"]
mock_task_manager.return_value = mock_tm_instance
args = parser.parse_args(["validate", "--tasks", "hellaswag,nonexistent"])
with patch("builtins.print") as mock_print:
with pytest.raises(SystemExit) as exc_info:
validate_cmd._execute(args)
assert exc_info.value.code == 1
mock_print.assert_any_call("Tasks not found: nonexistent")
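

# Hypothetical flow implied by the two tests above: Validate._execute asks
# TaskManager.match_tasks for the requested names, prints
# "Tasks not found: <missing>" for anything that did not match and exits with
# status 1; when everything matches it prints "All tasks found and valid".
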
class TestCLIUtils:
"""Test CLI utility functions."""
def test_try_parse_json_with_json_string(self):
"""Test try_parse_json with valid JSON string."""
result = try_parse_json('{"key": "value", "num": 42}')
assert result == {"key": "value", "num": 42}
def test_try_parse_json_with_dict(self):
"""Test try_parse_json with dict input."""
input_dict = {"key": "value"}
result = try_parse_json(input_dict)
assert result is input_dict
def test_try_parse_json_with_none(self):
"""Test try_parse_json with None."""
result = try_parse_json(None)
assert result is None
def test_try_parse_json_with_plain_string(self):
"""Test try_parse_json with plain string."""
result = try_parse_json("key=value,key2=value2")
assert result == "key=value,key2=value2"
def test_try_parse_json_with_invalid_json(self):
"""Test try_parse_json with invalid JSON."""
with pytest.raises(ValueError) as exc_info:
try_parse_json('{key: "value"}') # Invalid JSON (unquoted key)
assert "Invalid JSON" in str(exc_info.value)
assert "double quotes" in str(exc_info.value)
def test_int_or_none_list_single_value(self):
"""Test _int_or_none_list_arg_type with single value."""
result = _int_or_none_list_arg_type(3, 4, "0,1,2,3", "42")
assert result == [42, 42, 42, 42]
def test_int_or_none_list_multiple_values(self):
"""Test _int_or_none_list_arg_type with multiple values."""
result = _int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,20,30,40")
assert result == [10, 20, 30, 40]
def test_int_or_none_list_with_none(self):
"""Test _int_or_none_list_arg_type with None values."""
result = _int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,None,30,None")
assert result == [10, None, 30, None]
def test_int_or_none_list_invalid_value(self):
"""Test _int_or_none_list_arg_type with invalid value."""
with pytest.raises(ValueError):
_int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,invalid,30,40")
def test_int_or_none_list_too_few_values(self):
"""Test _int_or_none_list_arg_type with too few values."""
with pytest.raises(ValueError):
_int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,20")
def test_int_or_none_list_too_many_values(self):
"""Test _int_or_none_list_arg_type with too many values."""
with pytest.raises(ValueError):
_int_or_none_list_arg_type(3, 4, "0,1,2,3", "10,20,30,40,50")
def test_request_caching_arg_to_dict_none(self):
"""Test request_caching_arg_to_dict with None."""
result = request_caching_arg_to_dict(None)
assert result == {}
def test_request_caching_arg_to_dict_true(self):
"""Test request_caching_arg_to_dict with 'true'."""
result = request_caching_arg_to_dict("true")
assert result == {
"cache_requests": True,
"rewrite_requests_cache": False,
"delete_requests_cache": False,
}
def test_request_caching_arg_to_dict_refresh(self):
"""Test request_caching_arg_to_dict with 'refresh'."""
result = request_caching_arg_to_dict("refresh")
assert result == {
"cache_requests": True,
"rewrite_requests_cache": True,
"delete_requests_cache": False,
}
def test_request_caching_arg_to_dict_delete(self):
"""Test request_caching_arg_to_dict with 'delete'."""
result = request_caching_arg_to_dict("delete")
assert result == {
"cache_requests": False,
"rewrite_requests_cache": False,
"delete_requests_cache": True,
}
def test_check_argument_types_raises_on_untyped(self):
"""Test check_argument_types raises error for untyped arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--untyped") # No type specified
with pytest.raises(ValueError) as exc_info:
check_argument_types(parser)
assert "untyped" in str(exc_info.value)
assert "doesn't have a type specified" in str(exc_info.value)
def test_check_argument_types_passes_on_typed(self):
"""Test check_argument_types passes for typed arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--typed", type=str)
# Should not raise
check_argument_types(parser)
def test_check_argument_types_skips_const_actions(self):
"""Test check_argument_types skips const actions."""
parser = argparse.ArgumentParser()
parser.add_argument("--flag", action="store_const", const=True)
# Should not raise
check_argument_types(parser)
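

# A sketch consistent with the four request-caching cases above; hypothetical,
# and renamed with a _sketch_ prefix so the imported helper stays in scope.
def _sketch_request_caching_arg_to_dict(cache_requests):
    if cache_requests is None:
        return {}
    return {
        "cache_requests": cache_requests in ("true", "refresh"),
        "rewrite_requests_cache": cache_requests == "refresh",
        "delete_requests_cache": cache_requests == "delete",
    }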
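

# Likewise, a sketch matching the three check_argument_types tests above
# (hypothetical; renamed so the imported function is not shadowed).
def _sketch_check_argument_types(parser: argparse.ArgumentParser) -> None:
    for action in parser._actions:
        if action.nargs == 0 or action.const is not None:
            continue  # flags such as --help and store_const actions take no typed value
        if action.type is None:
            raise ValueError(f"Argument '{action.dest}' doesn't have a type specified.")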