Commit b9ee592b authored by Baber's avatar Baber
Browse files

nit

parent f3cfff61
...@@ -13,7 +13,7 @@ Equivalently, running the library can be done via the `lm-eval` entrypoint at th ...@@ -13,7 +13,7 @@ Equivalently, running the library can be done via the `lm-eval` entrypoint at th
The CLI now uses a subcommand structure for better organization: The CLI now uses a subcommand structure for better organization:
- `lm-eval run` - Execute evaluations (default behavior) - `lm-eval run` - Execute evaluations (default behavior)
- `lm-eval list` - List available tasks, models, etc. - `lm-eval ls` - List available tasks, models, etc.
- `lm-eval validate` - Validate task configurations - `lm-eval validate` - Validate task configurations
For backward compatibility, if no subcommand is specified, `run` is automatically inserted. So `lm-eval --model hf --tasks hellaswag` is equivalent to `lm-eval run --model hf --tasks hellaswag`. For backward compatibility, if no subcommand is specified, `run` is automatically inserted. So `lm-eval --model hf --tasks hellaswag` is equivalent to `lm-eval run --model hf --tasks hellaswag`.
......
from lm_eval._cli.eval import Eval from lm_eval._cli.harness import HarnessCLI
from lm_eval.utils import setup_logging from lm_eval.utils import setup_logging
def cli_evaluate() -> None: def cli_evaluate() -> None:
"""Main CLI entry point with subcommand and legacy support.""" """Main CLI entry point."""
setup_logging() setup_logging()
parser = Eval() parser = HarnessCLI()
args = parser.parse_args() args = parser.parse_args()
parser.execute(args) parser.execute(args)
......
...@@ -2,12 +2,12 @@ import argparse ...@@ -2,12 +2,12 @@ import argparse
import sys import sys
import textwrap import textwrap
from lm_eval._cli.listall import ListAll from lm_eval._cli.ls import List
from lm_eval._cli.run import Run from lm_eval._cli.run import Run
from lm_eval._cli.validate import Validate from lm_eval._cli.validate import Validate
class Eval: class HarnessCLI:
"""Main CLI parser that manages all subcommands.""" """Main CLI parser that manages all subcommands."""
def __init__(self): def __init__(self):
...@@ -20,7 +20,7 @@ class Eval: ...@@ -20,7 +20,7 @@ class Eval:
lm-eval run --model hf --model_args pretrained=gpt2 --tasks hellaswag lm-eval run --model hf --model_args pretrained=gpt2 --tasks hellaswag
# List available tasks # List available tasks
lm-eval list tasks lm-eval ls tasks
# Validate task configurations # Validate task configurations
lm-eval validate --tasks hellaswag,arc_easy lm-eval validate --tasks hellaswag,arc_easy
...@@ -40,7 +40,7 @@ class Eval: ...@@ -40,7 +40,7 @@ class Eval:
dest="command", help="Available commands", metavar="COMMAND" dest="command", help="Available commands", metavar="COMMAND"
) )
Run.create(self._subparsers) Run.create(self._subparsers)
ListAll.create(self._subparsers) List.create(self._subparsers)
Validate.create(self._subparsers) Validate.create(self._subparsers)
def parse_args(self) -> argparse.Namespace: def parse_args(self) -> argparse.Namespace:
......
...@@ -4,33 +4,33 @@ import textwrap ...@@ -4,33 +4,33 @@ import textwrap
from lm_eval._cli.subcommand import SubCommand from lm_eval._cli.subcommand import SubCommand
class ListAll(SubCommand): class List(SubCommand):
"""Command for listing available tasks.""" """Command for listing available tasks."""
def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs): def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
# Create and configure the parser # Create and configure the parser
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._parser = subparsers.add_parser( self._parser = subparsers.add_parser(
"list", "ls",
help="List available tasks, groups, subtasks, or tags", help="List available tasks, groups, subtasks, or tags",
description="List available tasks, groups, subtasks, or tags from the evaluation harness.", description="List available tasks, groups, subtasks, or tags from the evaluation harness.",
usage="lm-eval list [tasks|groups|subtasks|tags] [--include_path DIR]", usage="lm-eval list [tasks|groups|subtasks|tags] [--include_path DIR]",
epilog=textwrap.dedent(""" epilog=textwrap.dedent("""
examples: examples:
# List all available tasks (includes groups, subtasks, and tags) # List all available tasks (includes groups, subtasks, and tags)
$ lm-eval list tasks $ lm-eval ls tasks
# List only task groups (like 'mmlu', 'glue', 'superglue') # List only task groups (like 'mmlu', 'glue', 'superglue')
$ lm-eval list groups $ lm-eval ls groups
# List only individual subtasks (like 'mmlu_abstract_algebra') # List only individual subtasks (like 'mmlu_abstract_algebra')
$ lm-eval list subtasks $ lm-eval ls subtasks
# Include external task definitions # Include external task definitions
$ lm-eval list tasks --include_path /path/to/external/tasks $ lm-eval ls tasks --include_path /path/to/external/tasks
# List tasks from multiple external paths # List tasks from multiple external paths
$ lm-eval list tasks --include_path "/path/to/tasks1:/path/to/tasks2" $ lm-eval ls tasks --include_path "/path/to/tasks1:/path/to/tasks2"
organization: organization:
• Groups: Collections of tasks with aggregated metric across subtasks (e.g., 'mmlu') • Groups: Collections of tasks with aggregated metric across subtasks (e.g., 'mmlu')
...@@ -46,7 +46,7 @@ class ListAll(SubCommand): ...@@ -46,7 +46,7 @@ class ListAll(SubCommand):
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
) )
self._add_args() self._add_args()
self._parser.set_defaults(func=lambda arg: self._parser.print_help()) self._parser.set_defaults(func=self._execute)
def _add_args(self) -> None: def _add_args(self) -> None:
self._parser.add_argument( self._parser.add_argument(
...@@ -63,7 +63,7 @@ class ListAll(SubCommand): ...@@ -63,7 +63,7 @@ class ListAll(SubCommand):
help="Additional path to include if there are external tasks.", help="Additional path to include if there are external tasks.",
) )
def execute(self, args: argparse.Namespace) -> None: def _execute(self, args: argparse.Namespace) -> None:
"""Execute the list command.""" """Execute the list command."""
from lm_eval.tasks import TaskManager from lm_eval.tasks import TaskManager
......
...@@ -42,7 +42,7 @@ class Run(SubCommand): ...@@ -42,7 +42,7 @@ class Run(SubCommand):
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
) )
self._add_args() self._add_args()
self._parser.set_defaults(func=self.execute) self._parser.set_defaults(func=self._execute)
def _add_args(self) -> None: def _add_args(self) -> None:
self._parser = self._parser self._parser = self._parser
...@@ -313,14 +313,17 @@ class Run(SubCommand): ...@@ -313,14 +313,17 @@ class Run(SubCommand):
), ),
) )
def execute(self, args: argparse.Namespace) -> None: def _execute(self, args: argparse.Namespace) -> None:
"""Runs the evaluation harness with the provided arguments.""" """Runs the evaluation harness with the provided arguments."""
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from lm_eval.config.evaluate_config import EvaluatorConfig from lm_eval.config.evaluate_config import EvaluatorConfig
# Create and validate config (most validation now happens in EvaluationConfig) eval_logger = logging.getLogger(__name__)
# Create and validate config (most validation now occurs in EvaluationConfig)
cfg = EvaluatorConfig.from_cli(args) cfg = EvaluatorConfig.from_cli(args)
from lm_eval import simple_evaluate, utils from lm_eval import simple_evaluate
from lm_eval.loggers import EvaluationTracker, WandbLogger from lm_eval.loggers import EvaluationTracker, WandbLogger
from lm_eval.utils import handle_non_serializable, make_table from lm_eval.utils import handle_non_serializable, make_table
...@@ -328,10 +331,6 @@ class Run(SubCommand): ...@@ -328,10 +331,6 @@ class Run(SubCommand):
if cfg.wandb_args: if cfg.wandb_args:
wandb_logger = WandbLogger(cfg.wandb_args, cfg.wandb_config_args) wandb_logger = WandbLogger(cfg.wandb_args, cfg.wandb_config_args)
utils.setup_logging(cfg.verbosity)
eval_logger = logging.getLogger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Set up evaluation tracker # Set up evaluation tracker
if cfg.output_path: if cfg.output_path:
cfg.hf_hub_log_args["output_path"] = cfg.output_path cfg.hf_hub_log_args["output_path"] = cfg.output_path
...@@ -342,7 +341,7 @@ class Run(SubCommand): ...@@ -342,7 +341,7 @@ class Run(SubCommand):
evaluation_tracker = EvaluationTracker(**cfg.hf_hub_log_args) evaluation_tracker = EvaluationTracker(**cfg.hf_hub_log_args)
# Create task manager (metadata already set up in config validation) # Create task manager (metadata already set up in config validation)
task_manager = cfg.process_tasks() task_manager = cfg.process_tasks(cfg.metadata)
# Validation warnings (keep these in CLI as they're logging-specific) # Validation warnings (keep these in CLI as they're logging-specific)
if "push_samples_to_hub" in cfg.hf_hub_log_args and not cfg.log_samples: if "push_samples_to_hub" in cfg.hf_hub_log_args and not cfg.log_samples:
......
...@@ -17,8 +17,3 @@ class SubCommand(ABC): ...@@ -17,8 +17,3 @@ class SubCommand(ABC):
def _add_args(self) -> None: def _add_args(self) -> None:
"""Add arguments specific to this subcommand.""" """Add arguments specific to this subcommand."""
pass pass
@abstractmethod
def execute(self, args: argparse.Namespace) -> None:
"""Execute the subcommand with the given arguments."""
pass
...@@ -73,7 +73,7 @@ class Validate(SubCommand): ...@@ -73,7 +73,7 @@ class Validate(SubCommand):
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
) )
self._add_args() self._add_args()
self._parser.set_defaults(func=lambda arg: self._parser.print_help()) self._parser.set_defaults(func=self._execute)
def _add_args(self) -> None: def _add_args(self) -> None:
self._parser.add_argument( self._parser.add_argument(
...@@ -92,7 +92,7 @@ class Validate(SubCommand): ...@@ -92,7 +92,7 @@ class Validate(SubCommand):
help="Additional path to include if there are external tasks.", help="Additional path to include if there are external tasks.",
) )
def execute(self, args: argparse.Namespace) -> None: def _execute(self, args: argparse.Namespace) -> None:
"""Execute the validate command.""" """Execute the validate command."""
from lm_eval.tasks import TaskManager from lm_eval.tasks import TaskManager
......
...@@ -187,14 +187,6 @@ class EvaluatorConfig: ...@@ -187,14 +187,6 @@ class EvaluatorConfig:
metadata={"help": "Additional metadata for tasks that require it"}, metadata={"help": "Additional metadata for tasks that require it"},
) )
@staticmethod
def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
"""Parse string arguments that should be dictionaries."""
for key in config:
if key in DICT_KEYS and isinstance(config[key], str):
config[key] = simple_parse_args_string(config[key])
return config
@classmethod @classmethod
def from_cli(cls, namespace: Namespace) -> "EvaluatorConfig": def from_cli(cls, namespace: Namespace) -> "EvaluatorConfig":
""" """
...@@ -206,7 +198,7 @@ class EvaluatorConfig: ...@@ -206,7 +198,7 @@ class EvaluatorConfig:
# Load and merge YAML config if provided # Load and merge YAML config if provided
if used_config := hasattr(namespace, "config") and namespace.config: if used_config := hasattr(namespace, "config") and namespace.config:
config.update(cls._load_yaml_config(namespace.config)) config.update(cls.load_yaml_config(namespace.config))
# Override with CLI args (only truthy values, exclude non-config args) # Override with CLI args (only truthy values, exclude non-config args)
excluded_args = {"config", "command", "func"} # argparse internal args excluded_args = {"config", "command", "func"} # argparse internal args
...@@ -222,7 +214,7 @@ class EvaluatorConfig: ...@@ -222,7 +214,7 @@ class EvaluatorConfig:
instance = cls(**config) instance = cls(**config)
if used_config: if used_config:
print(textwrap.dedent(f"""{instance}""")) print(textwrap.dedent(f"""{instance}"""))
instance.validate_and_preprocess() instance.configure()
return instance return instance
...@@ -233,19 +225,24 @@ class EvaluatorConfig: ...@@ -233,19 +225,24 @@ class EvaluatorConfig:
Merges with built-in defaults and validates. Merges with built-in defaults and validates.
""" """
# Load YAML config # Load YAML config
yaml_config = cls._load_yaml_config(config_path) yaml_config = cls.load_yaml_config(config_path)
# Parse string arguments that should be dictionaries # Parse string arguments that should be dictionaries
yaml_config = cls._parse_dict_args(yaml_config) yaml_config = cls._parse_dict_args(yaml_config)
# Create instance and validate
instance = cls(**yaml_config) instance = cls(**yaml_config)
instance.validate_and_preprocess() instance.configure()
return instance return instance
@staticmethod @staticmethod
def _load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]: def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
"""Parse string arguments that should be dictionaries."""
for key in config:
if key in DICT_KEYS and isinstance(config[key], str):
config[key] = simple_parse_args_string(config[key])
return config
@staticmethod
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
"""Load and validate YAML config file.""" """Load and validate YAML config file."""
config_file = ( config_file = (
Path(config_path) if not isinstance(config_path, Path) else config_path Path(config_path) if not isinstance(config_path, Path) else config_path
...@@ -268,11 +265,11 @@ class EvaluatorConfig: ...@@ -268,11 +265,11 @@ class EvaluatorConfig:
return yaml_data return yaml_data
def validate_and_preprocess(self) -> None: def configure(self) -> None:
"""Validate configuration and preprocess fields after creation.""" """Validate configuration and preprocess fields after creation."""
self._validate_arguments() self._validate_arguments()
self._process_arguments() self._process_arguments()
self._apply_trust_remote_code() self._set_trust_remote_code()
def _validate_arguments(self) -> None: def _validate_arguments(self) -> None:
"""Validate configuration arguments and cross-field constraints.""" """Validate configuration arguments and cross-field constraints."""
...@@ -369,7 +366,7 @@ class EvaluatorConfig: ...@@ -369,7 +366,7 @@ class EvaluatorConfig:
self.tasks = task_names self.tasks = task_names
return task_manager return task_manager
def _apply_trust_remote_code(self) -> None: def _set_trust_remote_code(self) -> None:
"""Apply trust_remote_code setting if enabled.""" """Apply trust_remote_code setting if enabled."""
if self.trust_remote_code: if self.trust_remote_code:
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally, # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment