Commit 613d383e authored by Baber's avatar Baber
Browse files

cleanup

parent 768f55b3
from lm_eval._cli import Eval from lm_eval._cli.eval import Eval
def cli_evaluate() -> None: def cli_evaluate() -> None:
......
""" """
CLI subcommands to run from terminal. CLI subcommands to run from terminal.
""" """
from lm_eval._cli.base import SubCommand
from lm_eval._cli.eval import Eval
from lm_eval._cli.list import List
from lm_eval._cli.run import Run
from lm_eval._cli.validate import Validate
__all__ = [
"SubCommand",
"Run",
"List",
"Validate",
"Eval",
]
...@@ -2,7 +2,7 @@ import argparse ...@@ -2,7 +2,7 @@ import argparse
import sys import sys
import textwrap import textwrap
from lm_eval._cli.list import List from lm_eval._cli.listall import ListAll
from lm_eval._cli.run import Run from lm_eval._cli.run import Run
from lm_eval._cli.validate import Validate from lm_eval._cli.validate import Validate
...@@ -40,35 +40,20 @@ class Eval: ...@@ -40,35 +40,20 @@ class Eval:
dest="command", help="Available commands", metavar="COMMAND" dest="command", help="Available commands", metavar="COMMAND"
) )
Run.create(self._subparsers) Run.create(self._subparsers)
List.create(self._subparsers) ListAll.create(self._subparsers)
Validate.create(self._subparsers) Validate.create(self._subparsers)
def parse_args(self) -> argparse.Namespace: def parse_args(self) -> argparse.Namespace:
"""Parse arguments using the main parser.""" """Parse arguments using the main parser."""
if len(sys.argv) > 2 and sys.argv[1] not in self._subparsers.choices: if len(sys.argv) > 2 and sys.argv[1] not in self._subparsers.choices:
# Arguments provided but no valid subcommand - insert 'run' # Backward compatibility: arguments provided but no valid subcommand - insert 'run'
sys.argv.insert(1, "run") sys.argv.insert(1, "run")
elif len(sys.argv) == 2 and "run" in sys.argv:
# if only 'run' is specified, ensure it is treated as a subcommand
self._subparsers.choices["run"].print_help()
sys.exit(0)
return self._parser.parse_args() return self._parser.parse_args()
def execute(self, args: argparse.Namespace) -> None: def execute(self, args: argparse.Namespace) -> None:
"""Main execution method that handles subcommands and legacy support.""" """Main execution method that handles subcommands and legacy support."""
# Handle legacy task listing
if hasattr(args, "tasks") and args.tasks in [
"list",
"list_groups",
"list_subtasks",
"list_tags",
]:
print(
f"'--tasks {args.tasks}' is no longer supported.\n"
f"Use the 'list' command instead:\n",
file=sys.stderr,
)
# Show list command help
list_parser = self._subparsers.choices["list"]
list_parser.print_help()
sys.exit(1)
args.func(args) args.func(args)
import argparse import argparse
import textwrap import textwrap
from lm_eval._cli.base import SubCommand from lm_eval._cli.subcommand import SubCommand
class List(SubCommand): class ListAll(SubCommand):
"""Command for listing available tasks.""" """Command for listing available tasks."""
def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs): def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
......
...@@ -5,7 +5,7 @@ import os ...@@ -5,7 +5,7 @@ import os
import textwrap import textwrap
from functools import partial from functools import partial
from lm_eval._cli import SubCommand from lm_eval._cli.subcommand import SubCommand
from lm_eval._cli.utils import ( from lm_eval._cli.utils import (
_int_or_none_list_arg_type, _int_or_none_list_arg_type,
request_caching_arg_to_dict, request_caching_arg_to_dict,
...@@ -42,7 +42,7 @@ class Run(SubCommand): ...@@ -42,7 +42,7 @@ class Run(SubCommand):
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
) )
self._add_args() self._add_args()
self._parser.set_defaults(func=lambda arg: self._parser.print_help()) self._parser.set_defaults(func=self.execute)
def _add_args(self) -> None: def _add_args(self) -> None:
self._parser = self._parser self._parser = self._parser
......
...@@ -2,7 +2,7 @@ import argparse ...@@ -2,7 +2,7 @@ import argparse
import sys import sys
import textwrap import textwrap
from lm_eval._cli.base import SubCommand from lm_eval._cli.subcommand import SubCommand
class Validate(SubCommand): class Validate(SubCommand):
...@@ -81,7 +81,7 @@ class Validate(SubCommand): ...@@ -81,7 +81,7 @@ class Validate(SubCommand):
"-t", "-t",
required=True, required=True,
type=str, type=str,
metavar="task1,task2", metavar="TASK1,TASK2",
help="Comma-separated list of task names to validate", help="Comma-separated list of task names to validate",
) )
self._parser.add_argument( self._parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment